[BYOC][TENSORRT] Add support for FP16 on TensorRT BYOC flow #10388
Conversation
IIUC, in addition to supporting FP16 for TRT, this PR attempts to deprecate support for TRT < 7.0.0? Since we don't have a TRT runtime in CI, I have no clue how it affects existing use cases. If so, this would be a more important change and would need to be discussed and documented.
@@ -150,19 +165,30 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
// Get outputs.
node_output_map_[nid] = {};
for (auto out : params.outputs) {
VLOG(1) << "Before forcing output tensor type: " << static_cast<int>(out->getType())
        << std::endl;
No need for a newline in the log.
// According to documentation this is required for single FP precision. Always on doesnt seem to
// prevent pure FP32 execution
nit: better to provide a link to the documentation.
node_output_map_[nid].push_back(TensorRTOpInput(out));
VLOG(1) << "After forcing output tensor type: " << static_cast<int>(out->getType())
        << std::endl;
ditto
// Pass it explicitly
// config_->setFlag(nvinfer1::BuilderFlag::kDEBUG);
Remove?
@@ -204,19 +227,30 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {

nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
                                                        DLDeviceType src_device) {
VLOG(1) << "Device type for DLTensorAsWeight: " << dptr->device.device_type;
VLOG(1) << "DLType for DLTensorAsWeight: " << dptr->dtype;
VLOG(1) << "DLShape for DLTensorAsWeight: " << dptr->shape << std::endl;
ditto
@@ -169,50 +189,54 @@ def compile_and_run(mod, params, i_data, mode="vm", use_trt=True):
mod, params, i_data, mode=mode, use_trt=use_trt
)

print(result_dict)
remove
if run_module:
    assert_result_dict_holds(result_dict)
    print(result_dict)
remove
# run_and_verify_func(
#     get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1), run_module=run_module)
Uncomment?
@@ -471,7 +502,8 @@ def get_graph(
f = relay.Function([x], out)
return f, {"x": x_shape}, []

run_and_verify_func(get_graph(), run_module=run_module)
# for tp in ["float32", "float16", "int8", "uint8"]:
Remove, or uncomment?
# run_and_verify_func(get_graph((1, 1000), axis=-1), run_module=run_module)
# run_and_verify_func(get_graph((1, 3, 4), axis=-2), run_module=run_module)
# run_and_verify_func(get_graph((1, 3, 4), axis=1), run_module=run_module)
uncomment
(force-pushed from 81c53f4 to e36ceb0)
I reverted all of the versioning changes and kept this focused on the FP16 support. Thanks for the review, PTAL.
LGTM. Thanks. Just nits.
(force-pushed from 6a6640e to 5bdd0ed)
ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
auto tensor_dtype =
    (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
I'd suggest ICHECK failing if unsupported type.
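A minimal sketch of what that check could look like in AddInput, reusing the ICHECK macro and the dtypes vector already visible in the diff (the exact message wording is illustrative, not from the PR):

// Reject anything other than FP16/FP32 before choosing the TRT input type.
ICHECK(dtypes[i].code == kDLFloat && (dtypes[i].bits == 16 || dtypes[i].bits == 32))
    << "Only FP16 and FP32 inputs are supported by the TensorRT BYOC runtime.";
auto tensor_dtype =
    (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;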
@@ -202,9 +211,6 @@ def _func_wrapper(expr):
# ops with dynamic shapes are offloaded to VM
if check_dynamism(args, op_name):
    return False
if any([x.checked_type.dtype != "float32" for x in args]):
I'm not seeing where the type check (which must now be generalized to float32/float16) has gone. If we remove it altogether, then I think we'll either generate bad code or fail at TRT build time, which from the TVM user's point of view is runtime and too late. We also need the check in the predicate to prevent Collage from exploring invalid candidate kernels.
// Convert op to TRT.
converter->Convert(&params);

// Get outputs.
node_output_map_[nid] = {};
for (auto out : params.outputs) {
auto out_type = params.inputs.at(1).weight.type == params.inputs.at(0).tensor->getType()
Can you explain this? It seems very specific yet AddLayer is used for all of the supported ops.
This is unfortunately causing a vector index exception for me. I believe we need to pick up the output type from the node's dtype vector.
                    ? nvinfer1::DataType::kFLOAT
                    : nvinfer1::DataType::kINT32;
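A rough sketch of the direction suggested above (deriving the output type from the node's dtype vector rather than from inputs.at(1)); treat GetOpDataType() as an assumed accessor for the JSON graph node's per-output dtypes:

// Use the node's recorded output dtypes; inputs.at(1) is not present for every op
// that flows through AddLayer().
const auto& out_dtypes = node.GetOpDataType();  // assumed accessor on JSONGraphNode
for (size_t i = 0; i < params.outputs.size(); ++i) {
  auto* out = params.outputs[i];
  if (i < out_dtypes.size() && out_dtypes[i].bits == 16) {
    out->setType(nvinfer1::DataType::kHALF);
  }
  node_output_map_[nid].push_back(TensorRTOpInput(out));
}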
const auto trt_dtype = (static_cast<int>(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF
Another ICHECK would be in order to make sure we're not silently generating bad code.
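One possible shape for that check in GetDLTensorAsWeights, as a hedged sketch using the DLDataType fields already referenced in this file (message text is illustrative):

// If the dtype claims to be a float, make sure it is a width we can represent,
// rather than silently falling back to kFLOAT.
if (dptr->dtype.code == kDLFloat) {
  ICHECK(dptr->dtype.bits == 16 || dptr->dtype.bits == 32)
      << "Unsupported float weight width for TensorRT: " << dptr->dtype.bits << " bits";
}
const auto trt_dtype = (dptr->dtype.bits == 16) ? nvinfer1::DataType::kHALF
                                                : nvinfer1::DataType::kFLOAT;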
@@ -250,7 +253,7 @@ void TensorRTBuilder::CleanUp() {
#endif
builder_->destroy();
for (auto weight : trt_weights_) {
if (weight.type == nvinfer1::DataType::kFLOAT) {
if (static_cast<int>(weight.type) <= 1) {
Can we avoid hard coding the enum constants?
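One way to keep the named enum values, as a rough sketch (the body of the branch stays whatever the existing kFLOAT branch already does):

// Spell out the dtypes being handled instead of relying on their integer encoding.
const bool is_fp_weight = weight.type == nvinfer1::DataType::kFLOAT ||
                          weight.type == nvinfer1::DataType::kHALF;
if (is_fp_weight) {
  // ... free/release the weight buffer as in the original branch
}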
(force-pushed from e7405a9 to 2eb104b)
(force-pushed from 3a3e1e4 to 0741642)
(force-pushed from 0741642 to d0e508b)
@@ -85,8 +85,13 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
shape.erase(shape.begin());
}
nvinfer1::Dims dims = VectorToTrtDims(shape);
ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are supported.";
auto input_tensor = network_->addInput(name.c_str(), nvinfer1::DataType::kFLOAT, dims);
ICHECK((dtypes[i].bits != 16 || dtypes[i].bits != 32))
This is always true, I think you mean bits == 16 || bits == 32.
ret: bool
    True if supported, False if not.
"""
if any([x.checked_type.dtype in supported_types for x in args]):
if all(...):
    return True
log an error
return False
…0388)
* FP16 support for TRT
* Cleanups on tests
* Fix for typing on output tensor
* Fix icheck
* Add TRT inference builder auto-convert precision flags as attrs in the config
* Address PR comments
* Fix bug on passing the new config attrs to codegen for tensorrt partition

Co-authored-by: Michalis Papapdimitriou <mpapapdimitriou@octoml.ai>
This PR enables support for FP16 types on the TensorRT BYOC flow.
Changes:
@mbs-octoml @electriclilies @masahi