Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT EP] Add new provider option to exclude nodes from running on TRT #22681

Merged
merged 20 commits into from
Nov 13, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,6 @@ struct OrtTensorRTProviderOptionsV2 {

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
const char* trt_nodes_to_exclude{nullptr}; // Exclude specific nodes from running on TRT e.g. "NonMaxSuppression,NonZero,RoiAlign".
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
// Adding '~' followed by a node e.g. '~NonZero', indicates that TRT EP will ensure this node is included for input to the TRT parser.
};
108 changes: 89 additions & 19 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
engine_hw_compatible_ = info.engine_hw_compatible;
nodes_to_exclude_ = info.nodes_to_exclude;

} else {
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
Expand Down Expand Up @@ -1565,6 +1567,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}

const std::string nodes_to_exclude_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kNodesToExclude);
if (!nodes_to_exclude_env.empty()) {
nodes_to_exclude_ = nodes_to_exclude_env;
}

} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
Expand Down Expand Up @@ -1725,6 +1732,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
}

trt_version_ = getInferLibVersion();

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_;

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
<< "device_id: " << device_id_
<< ", trt_max_partition_iterations: " << max_partition_iterations_
Expand Down Expand Up @@ -1762,7 +1773,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_
<< ", trt_nodes_to_exclude: " << nodes_to_exclude_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
Expand Down Expand Up @@ -2430,6 +2442,18 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t&
return cycle_detected;
}

std::set<std::string> GetExcludedNodeSet(std::string node_list_to_exclude) {
std::set<std::string> set;
if (!node_list_to_exclude.empty()) {
std::stringstream node_list(node_list_to_exclude);
std::string node;
while (std::getline(node_list, node, ',')) {
set.insert(node);
}
}
return set;
}

std::vector<std::unique_ptr<ComputeCapability>>
TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
const IKernelLookup& /*kernel_lookup*/) const {
Expand Down Expand Up @@ -2462,10 +2486,42 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);

std::vector<size_t> filtered_nodes_vector;
std::set<std::string> exclude_set = GetExcludedNodeSet(nodes_to_exclude_);

/*
* There is a known performance issue with the DDS nodes (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.6.
* TRT EP automatically excludes DDS nodes from running on TRT unless the user explicitly specifies that those nodes should be included.
*
* Note: "~node_name" means to include the node.
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
*/
if (trt_version_ >= 100000 && trt_version_ < 100700) {
if (exclude_set.find("~NonMaxSuppression") == exclude_set.end()) exclude_set.insert("NonMaxSuppression");
if (exclude_set.find("~NonZero") == exclude_set.end()) exclude_set.insert("NonZero");
if (exclude_set.find("~RoiAlign") == exclude_set.end()) exclude_set.insert("RoiAlign");
}
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved

// Print excluded nodes, if any.
std::set<std::string>::iterator it;
for (it = exclude_set.begin(); it != exclude_set.end(); ++it) {
std::string node = *it;
if (node.find("~") == 0) continue;
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Exclude " << node << " from running on TRT";
if (node == "NonMaxSuppression" || node == "NonZero" || node == "RoiAlign") {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Add \"~" << node << "\" in trt_nodes_to_exclude if " << node << " should be included in the input to TRT parser. However, it still depends on TRT parser to determine the eligibility of this node for TRT";
}
}

SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
bool new_subgraph = true;

/* Iterate all the nodes and exclude the node if:
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
* 2. It's in the exlucded set which specified by trt_nodes_to_exclude.
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
*/
for (const auto& index : nodes_vector) {
const auto& node = graph.GetNode(node_index[index]);
bool supported_node = true;

/* If current node is control flow op, we take different approach based on following four cases:
*
Expand All @@ -2477,29 +2533,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
* For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP.
*/
if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
auto sub_graphs = node->GetSubgraphs();
if (sub_graphs.size() != 0) {
bool all_subgraphs_are_supported = true;
for (auto sub_graph : sub_graphs) {
// TRT EP should consider the empty subgraph is fully supported by TRT.
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
continue;
}
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
all_subgraphs_are_supported = false;
break;
auto supported_control_flow_op = [&](const Node* node) {
auto sub_graphs = node->GetSubgraphs();
if (sub_graphs.size() != 0) {
for (auto sub_graph : sub_graphs) {
// TRT EP should consider the empty subgraph is fully supported by TRT.
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
continue;
}
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
// if not all its subgraphs are supported, we need to exclude this control flow op
return false;
}
}
}
if (!all_subgraphs_are_supported) {
// if not all its subgraphs are supported, we need to exclude this control flow op
continue;
}
return true;
};
supported_node = supported_control_flow_op(node);
}

// Exclude any nodes, if applicable
if (exclude_set.find(node->OpType()) != exclude_set.end()) {
supported_node = false;
}

if (supported_node) {
if (new_subgraph) {
parser_nodes_vector.emplace_back();
// Mark all new graphs as "UnKnown" which will later be parsed by TRT parser
parser_nodes_vector.back().second = false;
new_subgraph = false;
}
parser_nodes_vector.back().first.emplace_back(index);
} else {
new_subgraph = true;
}
filtered_nodes_vector.push_back(index);
}

SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}};
bool early_termination = false;
supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
if (early_termination) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL";
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX";
static const std::string kNodesToExclude = "ORT_TENSORRT_NODES_TO_EXCLUDE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
Expand Down Expand Up @@ -329,6 +330,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool cuda_graph_enable_ = false;
std::string cache_prefix_;
bool engine_hw_compatible_ = false;
std::string nodes_to_exclude_;

// The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH
int32_t trt_version_;

chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
// The OrtAllocator object will be get during ep compute time
// and should be kept for the lifetime of TRT EP object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";
constexpr const char* kONNXBytestream = "trt_onnx_bytestream";
constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size";
constexpr const char* kNodesToExclude = "trt_nodes_to_exclude";

} // namespace provider_option_names
} // namespace tensorrt
Expand Down Expand Up @@ -134,6 +135,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size)
.AddAssignmentToReference(tensorrt::provider_option_names::kNodesToExclude, info.nodes_to_exclude)
.Parse(options)); // add new provider option here.

info.user_compute_stream = user_compute_stream;
Expand Down Expand Up @@ -188,6 +190,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)},
{tensorrt::provider_option_names::kNodesToExclude, MakeStringWithClassicLocale(info.nodes_to_exclude)},
};
return options;
}
Expand All @@ -206,6 +209,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes);
const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path);
const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path);
const std::string kNodesToExclude_ = empty_if_null(info.trt_nodes_to_exclude);

const ProviderOptions options{
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
Expand Down Expand Up @@ -251,6 +255,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.trt_onnx_bytestream))},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)},
{tensorrt::provider_option_names::kNodesToExclude, kNodesToExclude_},
};
return options;
}
Expand Down Expand Up @@ -355,5 +360,6 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible;
trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream;
trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size;
trt_provider_options_v2.trt_nodes_to_exclude = copy_string_if_needed(internal_options.nodes_to_exclude);
}
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct TensorrtExecutionProviderInfo {
int ep_context_embed_mode{0};
std::string engine_cache_prefix{""};
bool engine_hw_compatible{false};
std::string nodes_to_exclude{""};

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ struct Tensorrt_Provider : Provider {
info.engine_hw_compatible = options.trt_engine_hw_compatible != 0;
info.onnx_bytestream = options.trt_onnx_bytestream;
info.onnx_bytestream_size = options.trt_onnx_bytestream_size;
info.nodes_to_exclude = options.trt_nodes_to_exclude==nullptr ? "" : options.trt_nodes_to_exclude;

chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
return std::make_shared<TensorrtProviderFactory>(info);
}
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2409,6 +2409,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor
delete[] ptr->trt_profile_opt_shapes;
delete[] ptr->trt_ep_context_file_path;
delete[] ptr->trt_onnx_model_folder_path;
delete[] ptr->trt_nodes_to_exclude;
}

std::unique_ptr<OrtTensorRTProviderOptionsV2> p(ptr);
Expand Down
Loading