Skip to content

Commit

Permalink
[TensorRT EP] Exclude DDS ops from running on TRT (#22875)
Browse files Browse the repository at this point in the history
TRT EP now excludes DDS (data-dependent shape) ops from running on TRT, and
users cannot change this behavior.
This PR is for ORT 1.20.1 patch release.

We will have a better solution that adds a new provider option to exclude
specific ops, similar to the following:
#22863
#22681
  • Loading branch information
chilo-ms authored and yf711 committed Nov 18, 2024
1 parent 04a9bb1 commit be4e2a5
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 18 deletions.
74 changes: 56 additions & 18 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1725,6 +1725,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
}

trt_version_ = getInferLibVersion();

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_;

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
<< "device_id: " << device_id_
<< ", trt_max_partition_iterations: " << max_partition_iterations_
Expand Down Expand Up @@ -2462,10 +2466,30 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);

std::vector<size_t> filtered_nodes_vector;
std::set<std::string> exclude_ops_set;

Check warning on line 2469 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <set> for set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2469: Add #include <set> for set<> [build/include_what_you_use] [4]

/*
* There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10.
* TRT EP automatically excludes DDS ops from running on TRT.
*/
if (trt_version_ >= 100000 && trt_version_ < 110000) {
exclude_ops_set.insert("NonMaxSuppression");
exclude_ops_set.insert("NonZero");
exclude_ops_set.insert("RoiAlign");
LOGS_DEFAULT(VERBOSE) << "There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable";

Check warning on line 2479 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2479: Lines should be <= 120 characters long [whitespace/line_length] [2]
}

SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
bool new_subgraph = true;

/* Iterate all the nodes and exclude the node if:
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
* 2. It's a DDS op.
*/
for (const auto& index : nodes_vector) {
const auto& node = graph.GetNode(node_index[index]);
bool supported_node = true;

/* If current node is control flow op, we take different approach based on following four cases:
*
Expand All @@ -2477,29 +2501,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
* For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP.
*/
if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
auto sub_graphs = node->GetSubgraphs();
if (sub_graphs.size() != 0) {
bool all_subgraphs_are_supported = true;
for (auto sub_graph : sub_graphs) {
// TRT EP should consider the empty subgraph is fully supported by TRT.
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
continue;
}
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
all_subgraphs_are_supported = false;
break;
auto supported_control_flow_op = [&](const Node* node) {
auto sub_graphs = node->GetSubgraphs();
if (sub_graphs.size() != 0) {
for (auto sub_graph : sub_graphs) {
// TRT EP should consider the empty subgraph is fully supported by TRT.
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
continue;
}
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
// if not all its subgraphs are supported, we need to exclude this control flow op
return false;
}
}
}
if (!all_subgraphs_are_supported) {
// if not all its subgraphs are supported, we need to exclude this control flow op
continue;
}
return true;
};
supported_node = supported_control_flow_op(node);
}

// Exclude any ops, if applicable
if (exclude_ops_set.find(node->OpType()) != exclude_ops_set.end()) {
supported_node = false;
}

if (supported_node) {
if (new_subgraph) {
parser_nodes_vector.emplace_back();
// Mark all new graphs as "UnKnown" which will later be parsed by TRT parser
parser_nodes_vector.back().second = false;
new_subgraph = false;
}
parser_nodes_vector.back().first.emplace_back(index);
} else {
new_subgraph = true;
}
filtered_nodes_vector.push_back(index);
}

SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}};
bool early_termination = false;
supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
if (early_termination) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool cuda_graph_enable_ = false;
std::string cache_prefix_;
bool engine_hw_compatible_ = false;
std::string op_types_to_exclude_;

// The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH
int32_t trt_version_;

// The OrtAllocator object will be get during ep compute time
// and should be kept for the lifetime of TRT EP object.
Expand Down

0 comments on commit be4e2a5

Please sign in to comment.