Skip to content

Commit

Permalink
[TensorRT EP] Add new provider option to exclude nodes from running on TRT (#22681)
Browse files Browse the repository at this point in the history

Add new provider option `trt_op_types_to_exclude`:
- User can provide op type list to be excluded from running on TRT
- e.g. `trt_op_types_to_exclude="MaxPool"`

There is a known performance issue with the DDS ops (NonMaxSuppression,
NonZero and RoiAlign) from TRT versions 10.0 to 10.7. TRT EP excludes
DDS ops from running on TRT by default, user can override default value
with empty string to include all ops.
  • Loading branch information
chilo-ms authored and guschmue committed Dec 2, 2024
1 parent 0a4a323 commit f1c3770
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,23 @@ struct OrtTensorRTProviderOptionsV2 {
* directory by means of the "trt_onnx_model_folder_path" option.
*
*/
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false,
// nonzero = true
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
// the ONNX model containing the weights (applicable only when
// the "trt_weight_stripped_engine_enable" option is enabled)
const void* trt_onnx_bytestream{nullptr}; // The byte stream of the original ONNX model containing the weights
// (applicable only when the "trt_weight_stripped_engine_enable"
// option is enabled)
// can be updated using: UpdateTensorRTProviderOptionsWithValue
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
// can be updated using: UpdateTensorRTProviderOptionsWithValue

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false,
// nonzero = true
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
// the ONNX model containing the weights (applicable only when
// the "trt_weight_stripped_engine_enable" option is enabled)
const void* trt_onnx_bytestream{nullptr}; // The byte stream of the original ONNX model containing the weights
// (applicable only when the "trt_weight_stripped_engine_enable"
// option is enabled)
// can be updated using: UpdateTensorRTProviderOptionsWithValue
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
// can be updated using: UpdateTensorRTProviderOptionsWithValue
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
const char* trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"}; // Exclude specific ops from running on TRT.
// There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7.
// TRT EP excludes DDS ops from running on TRT by default, user can override default value with empty string to include all ops.
};
93 changes: 74 additions & 19 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
engine_hw_compatible_ = info.engine_hw_compatible;
op_types_to_exclude_ = info.op_types_to_exclude;

} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
Expand Down Expand Up @@ -1565,6 +1567,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}

const std::string op_types_to_exclude_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOpTypesToExclude);
if (!op_types_to_exclude_env.empty()) {
op_types_to_exclude_ = op_types_to_exclude_env;
}

} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
Expand Down Expand Up @@ -1725,6 +1732,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
}

trt_version_ = getInferLibVersion();

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_;

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
<< "device_id: " << device_id_
<< ", trt_max_partition_iterations: " << max_partition_iterations_
Expand Down Expand Up @@ -1762,7 +1773,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_
<< ", trt_op_types_to_exclude: " << op_types_to_exclude_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
Expand Down Expand Up @@ -2430,6 +2442,18 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t&
return cycle_detected;
}

// Parse a comma-separated list of op type names into a set.
// Used for the "trt_op_types_to_exclude" provider option, e.g.
// "NonMaxSuppression,NonZero,RoiAlign" -> {"NonMaxSuppression", "NonZero", "RoiAlign"}.
// An empty input string yields an empty set (i.e. no ops are excluded).
// Takes the list by const reference to avoid copying the option string on every call.
std::set<std::string> GetExcludedNodeSet(const std::string& node_list_to_exclude) {
  std::set<std::string> set;
  if (!node_list_to_exclude.empty()) {
    std::stringstream node_list(node_list_to_exclude);
    std::string node;
    // std::getline with ',' as delimiter splits the list; empty segments
    // (e.g. from "A,,B") are inserted as empty strings, matching prior behavior.
    while (std::getline(node_list, node, ',')) {
      set.insert(node);
    }
  }
  return set;
}

std::vector<std::unique_ptr<ComputeCapability>>
TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
const IKernelLookup& /*kernel_lookup*/) const {
Expand Down Expand Up @@ -2462,10 +2486,27 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);

std::vector<size_t> filtered_nodes_vector;
std::set<std::string> exclude_set = GetExcludedNodeSet(op_types_to_exclude_);

// Print excluded nodes, if any.
std::set<std::string>::iterator it;
for (it = exclude_set.begin(); it != exclude_set.end(); ++it) {
std::string op = *it;
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Exclude \"" << op << "\" from running on TRT, if any.";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Remove \"" << op << "\" from trt_op_types_to_exclude or specify trt_op_types_to_exclude with empty string to include the op in the input to TRT parser. However, it still depends on TRT parser to determine the eligibility of this op for TRT.";
}

SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
bool new_subgraph = true;

/* Iterate all the nodes and exclude the node if:
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
 * 2. It's in the excluded set which is specified by trt_op_types_to_exclude.
*/
for (const auto& index : nodes_vector) {
const auto& node = graph.GetNode(node_index[index]);
bool supported_node = true;

/* If current node is control flow op, we take different approach based on following four cases:
*
Expand All @@ -2477,29 +2518,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
* For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP.
*/
if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
auto sub_graphs = node->GetSubgraphs();
if (sub_graphs.size() != 0) {
bool all_subgraphs_are_supported = true;
for (auto sub_graph : sub_graphs) {
// TRT EP should consider the empty subgraph is fully supported by TRT.
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
continue;
}
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
all_subgraphs_are_supported = false;
break;
auto supported_control_flow_op = [&](const Node* node) {
auto sub_graphs = node->GetSubgraphs();
if (sub_graphs.size() != 0) {
for (auto sub_graph : sub_graphs) {
// TRT EP should consider the empty subgraph is fully supported by TRT.
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
continue;
}
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
// if not all its subgraphs are supported, we need to exclude this control flow op
return false;
}
}
}
if (!all_subgraphs_are_supported) {
// if not all its subgraphs are supported, we need to exclude this control flow op
continue;
}
return true;
};
supported_node = supported_control_flow_op(node);
}

// Exclude any ops, if applicable
if (exclude_set.find(node->OpType()) != exclude_set.end()) {
supported_node = false;
}

if (supported_node) {
if (new_subgraph) {
parser_nodes_vector.emplace_back();
// Mark all new graphs as "UnKnown" which will later be parsed by TRT parser
parser_nodes_vector.back().second = false;
new_subgraph = false;
}
parser_nodes_vector.back().first.emplace_back(index);
} else {
new_subgraph = true;
}
filtered_nodes_vector.push_back(index);
}

SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}};
bool early_termination = false;
supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
if (early_termination) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL";
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX";
static const std::string kOpTypesToExclude = "ORT_TENSORRT_OP_TYPES_TO_EXCLUDE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
Expand Down Expand Up @@ -329,6 +330,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool cuda_graph_enable_ = false;
std::string cache_prefix_;
bool engine_hw_compatible_ = false;
std::string op_types_to_exclude_;

// The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH
int32_t trt_version_;

// The OrtAllocator object will be get during ep compute time
// and should be kept for the lifetime of TRT EP object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";
constexpr const char* kONNXBytestream = "trt_onnx_bytestream";
constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size";
constexpr const char* kOpTypesToExclude = "trt_op_types_to_exclude";

} // namespace provider_option_names
} // namespace tensorrt
Expand Down Expand Up @@ -134,6 +135,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size)
.AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude)
.Parse(options)); // add new provider option here.

info.user_compute_stream = user_compute_stream;
Expand Down Expand Up @@ -188,6 +190,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)},
{tensorrt::provider_option_names::kOpTypesToExclude, MakeStringWithClassicLocale(info.op_types_to_exclude)},
};
return options;
}
Expand All @@ -206,6 +209,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes);
const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path);
const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path);
const std::string kOpTypesToExclude_ = empty_if_null(info.trt_op_types_to_exclude);

const ProviderOptions options{
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
Expand Down Expand Up @@ -251,6 +255,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.trt_onnx_bytestream))},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)},
{tensorrt::provider_option_names::kOpTypesToExclude, kOpTypesToExclude_},
};
return options;
}
Expand Down Expand Up @@ -355,5 +360,6 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible;
trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream;
trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size;
trt_provider_options_v2.trt_op_types_to_exclude = copy_string_if_needed(internal_options.op_types_to_exclude);
}
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ struct TensorrtExecutionProviderInfo {
int ep_context_embed_mode{0};
std::string engine_cache_prefix{""};
bool engine_hw_compatible{false};
// There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7.
// TRT EP excludes DDS ops from running on TRT by default, user can override default value of trt_op_types_to_exclude with empty string to include all ops.
std::string op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ struct Tensorrt_Provider : Provider {
info.engine_hw_compatible = options.trt_engine_hw_compatible != 0;
info.onnx_bytestream = options.trt_onnx_bytestream;
info.onnx_bytestream_size = options.trt_onnx_bytestream_size;
info.op_types_to_exclude = options.trt_op_types_to_exclude == nullptr ? "" : options.trt_op_types_to_exclude;

return std::make_shared<TensorrtProviderFactory>(info);
}
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2294,8 +2294,11 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions,
#ifdef USE_TENSORRT
onnxruntime::ProviderOptions provider_options_map;
for (size_t i = 0; i != num_keys; ++i) {
if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' ||
provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') {
// Don't allow key and value to be empty except the value of trt_op_types_to_exclude
if (provider_options_keys[i] == nullptr ||
provider_options_keys[i][0] == '\0' ||
(provider_options_values[i] == nullptr && strcmp("trt_op_types_to_exclude", provider_options_keys[i])) ||
(provider_options_values[i][0] == '\0' && strcmp("trt_op_types_to_exclude", provider_options_keys[i]))) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty");
}

Expand Down Expand Up @@ -2410,6 +2413,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor
delete[] ptr->trt_profile_opt_shapes;
delete[] ptr->trt_ep_context_file_path;
delete[] ptr->trt_onnx_model_folder_path;
if (!ptr->trt_op_types_to_exclude) delete[] ptr->trt_op_types_to_exclude;
}

std::unique_ptr<OrtTensorRTProviderOptionsV2> p(ptr);
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
// and TRT EP instance, so it won't be released.)
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
onnx_model_folder_path;
onnx_model_folder_path, trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptionsV2 params;
Expand Down Expand Up @@ -824,6 +824,9 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_op_types_to_exclude") {
trt_op_types_to_exclude = option.second;
params.trt_op_types_to_exclude = trt_op_types_to_exclude.c_str();
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
Expand Down
Loading

0 comments on commit f1c3770

Please sign in to comment.