Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC][TENSORRT] Add support for FP16 on TensorRT BYOC flow #10388

Merged
merged 7 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 14 additions & 104 deletions python/tvm/relay/op/contrib/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,6 @@ def _func_wrapper(expr):
# ops with dynamic shapes are offloaded to VM
if check_dynamism(args, op_name):
return False
if any([x.checked_type.dtype != "float32" for x in args]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not seeing where the type check (which must now be generalized to float32/float16) has gone to. If we remove it altogether then I think we'll either generate bad code or fail at TRT build time, which from the TVM user's point of view is runtime and too late. We also need to check in the predicate to prevent Collage from exploring invalid candidate kernels.

logger.info("Only float32 inputs are supported for TensorRT.")
return False
if op_name == "multiply":
shapes = [
[
Expand Down Expand Up @@ -325,9 +322,6 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]):
return False

if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if (
not get_tensorrt_use_implicit_batch_mode()
and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
Expand All @@ -347,9 +341,6 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.batch_norm is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
return False
Expand All @@ -366,10 +357,7 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
def softmax_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.softmax is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
logger.info("nn.softmax: can't modify batch dimension.")
return False
Expand All @@ -380,10 +368,7 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable
def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv1d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.data_layout != "NCW":
logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
return False
Expand All @@ -397,10 +382,7 @@ def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
def conv2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.data_layout != "NCHW":
logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
return False
Expand All @@ -418,9 +400,6 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if dense is supported by TensorRT."""

args = expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
input_rank = len(args[0].checked_type.shape)
weight_rank = len(args[1].checked_type.shape)
if input_rank not in (2, 3, 4):
Expand All @@ -436,9 +415,6 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
def batch_matmul_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

if any([x.checked_type.dtype != "float32" for x in expr.args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
expr.args[1].checked_type.shape
):
Expand All @@ -451,9 +427,6 @@ def batch_matmul_annotate_fn(expr):
def layer_norm_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

if any([x.checked_type.dtype != "float32" for x in expr.args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
logger.info("nn.layer_norm: requires use_implict_batch=False.")
return False
Expand All @@ -465,9 +438,6 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.bias_add is supported by TensorRT."""

args = expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
input_rank = len(args[0].checked_type.shape)
if input_rank not in (2, 3, 4):
logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank)
Expand All @@ -479,10 +449,7 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.max_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.layout != "NCHW":
logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
Expand All @@ -496,10 +463,7 @@ def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.avg_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.layout != "NCHW":
logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout)
return False
Expand All @@ -526,10 +490,7 @@ def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.global_max_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.layout != "NCHW":
logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
Expand All @@ -540,10 +501,7 @@ def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.global_avg_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.layout != "NCHW":
logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout)
return False
Expand All @@ -554,10 +512,7 @@ def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if expand_dims is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
logger.info("expand_dims: can't modify batch dimension.")
return False
Expand All @@ -568,10 +523,7 @@ def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable
def squeeze_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if squeeze is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if not attrs.axis:
logger.info("squeeze: must explicitly set axis.")
return False
Expand All @@ -586,9 +538,6 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if concatenate is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not get_tensorrt_use_implicit_batch_mode():
return True
if int(attrs.axis) == 0:
Expand All @@ -606,9 +555,6 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
def split_annotate_fn(expr):
"""Check if split is supported by TensorRT."""

if any([x.checked_type.dtype != "float32" for x in expr.args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
logger.info("split: can't modify batch dimension.")
return False
Expand All @@ -619,10 +565,7 @@ def split_annotate_fn(expr):
def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d_transpose is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if attrs.data_layout != "NCHW":
logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout)
return False
Expand All @@ -644,10 +587,7 @@ def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
def transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if transpose is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
logger.info("transpose: can't modify batch dimension.")
return False
Expand All @@ -658,10 +598,7 @@ def transpose_annotate_fn(expr): # pylint: disable=unused-variable
def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if layout_transform is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if (attrs.src_layout, attrs.dst_layout) not in [
("NCHW", "NHWC"),
("NHWC", "NCHW"),
Expand All @@ -679,9 +616,6 @@ def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable
def reshape_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if reshape is supported by TensorRT."""
attrs, args = expr.attrs, expr.args
if args[0].checked_type.dtype != "float32":
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if any([x < -1 for x in map(int, attrs.newshape)]):
logger.info("reshape: new shape dims must be explicit.")
return False
Expand Down Expand Up @@ -740,9 +674,6 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
pad_value = args[1]
assert isinstance(pad_value, relay.Constant)
pad_value = pad_value.data.numpy().item()
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if attrs.pad_mode != "constant":
logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
return False
Expand All @@ -766,9 +697,6 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if strided_slice is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if args[0].checked_type.dtype != "float32":
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"):
return False
if get_tensorrt_use_implicit_batch_mode():
Expand Down Expand Up @@ -813,10 +741,7 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
def adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.adaptive_max_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).")
return False
Expand All @@ -827,10 +752,7 @@ def adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
def adaptive_avg_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.adaptive_avg_pool2d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
attrs = expr.attrs
if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).")
return False
Expand All @@ -842,9 +764,6 @@ def conv3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"):
return False
if attrs.data_layout != "NCDHW":
Expand All @@ -864,9 +783,6 @@ def max_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.max_pool3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"):
return False
if attrs.layout != "NCDHW":
Expand All @@ -880,9 +796,6 @@ def avg_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.avg_pool3d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"):
return False
if attrs.layout != "NCDHW":
Expand All @@ -896,9 +809,6 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv3d_transpose is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"):
return False
if attrs.data_layout != "NCDHW":
Expand Down
26 changes: 15 additions & 11 deletions src/runtime/contrib/tensorrt/tensorrt_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
shape.erase(shape.begin());
}
nvinfer1::Dims dims = VectorToTrtDims(shape);
ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
auto tensor_dtype =
(dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest ICHECK failing if unsupported type.


auto input_tensor = network_->addInput(name.c_str(), tensor_dtype, dims);
node_output_map_[nid].push_back(TensorRTOpInput(input_tensor));
network_input_names_.push_back(name);
entry_id_map_[name] = entry_id + i;
Expand Down Expand Up @@ -139,17 +141,21 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
<< " requires weights but got a tensor.";
}
}
VLOG(1) << "INT " << input.type;
mikepapadim marked this conversation as resolved.
Show resolved Hide resolved
params.inputs.push_back(input);
}
ICHECK(converter->variable_input_count || converter->input_types.size() == params.inputs.size())
<< "Op expected a different number of inputs.";

// Convert op to TRT.
converter->Convert(&params);

// Get outputs.
node_output_map_[nid] = {};
for (auto out : params.outputs) {
auto out_type = params.inputs.at(1).weight.type == params.inputs.at(0).tensor->getType()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this? It seems very specific yet AddLayer is used for all of the supported ops.

Copy link
Contributor

@mbs-octoml mbs-octoml Mar 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unfortunately causing an vector index exception for me. I believe we need to pick up the output type from the node's dtype vector.

? params.inputs.at(0).tensor->getType()
: params.inputs.at(1).weight.type;
out->setType(out_type);

node_output_map_[nid].push_back(TensorRTOpInput(out));
}
}
Expand Down Expand Up @@ -205,18 +211,16 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
DLDeviceType src_device) {
ICHECK_EQ(dptr->device.device_type, src_device);
ICHECK(static_cast<int>(dptr->dtype.code) == kDLFloat ||
static_cast<int>(dptr->dtype.code) == kDLInt);
const auto trt_dtype = static_cast<int>(dptr->dtype.code) == kDLFloat
? nvinfer1::DataType::kFLOAT
: nvinfer1::DataType::kINT32;

const auto trt_dtype = (static_cast<int>(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another ICHECK would be in order to make sure we're not silently generating bad code.

: nvinfer1::DataType::kFLOAT;

const size_t weight_bytes = GetDataSize(*dptr);
nvinfer1::Weights weight{trt_dtype, nullptr, 0};
size_t count = 1;
for (tvm_index_t i = 0; i < dptr->ndim; ++i) {
count *= dptr->shape[i];
}
ICHECK_EQ(count * 4, weight_bytes);
weight.count = count;
weight.values = new float[count];
ICHECK_EQ(TVMArrayCopyToBytes(const_cast<DLTensor*>(dptr), const_cast<void*>(weight.values),
Expand Down Expand Up @@ -250,7 +254,7 @@ void TensorRTBuilder::CleanUp() {
#endif
builder_->destroy();
for (auto weight : trt_weights_) {
if (weight.type == nvinfer1::DataType::kFLOAT) {
if (static_cast<int>(weight.type) <= 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid hard coding the enum constants?

delete[] static_cast<const float*>(weight.values);
} else {
delete[] static_cast<const uint16_t*>(weight.values);
Expand Down
Loading