Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT EP] Add new provider option to exclude nodes from running on TRT #22681

Merged
merged 20 commits into from
Nov 13, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,7 @@ struct OrtTensorRTProviderOptionsV2 {

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
const char* trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"}; // Exclude specific ops from running on TRT.
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
// There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7.
// TRT EP excludes DDS ops from running on TRT by default, user can override default value with empty string to include all ops.
};
93 changes: 74 additions & 19 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,8 @@
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
engine_hw_compatible_ = info.engine_hw_compatible;
op_types_to_exclude_ = info.op_types_to_exclude;

} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
Expand Down Expand Up @@ -1565,6 +1567,11 @@
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}

const std::string op_types_to_exclude_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOpTypesToExclude);
if (!op_types_to_exclude_env.empty()) {
op_types_to_exclude_ = op_types_to_exclude_env;
}

} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
Expand Down Expand Up @@ -1725,6 +1732,10 @@
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
}

trt_version_ = getInferLibVersion();

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_;

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
<< "device_id: " << device_id_
<< ", trt_max_partition_iterations: " << max_partition_iterations_
Expand Down Expand Up @@ -1762,7 +1773,8 @@
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_
<< ", trt_op_types_to_exclude: " << op_types_to_exclude_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
Expand Down Expand Up @@ -2430,6 +2442,18 @@
return cycle_detected;
}

// Splits a comma-separated list of op types (e.g. "NonMaxSuppression,NonZero,RoiAlign")
// into a set for fast membership tests when filtering nodes in GetCapability().
// An empty input yields an empty set, i.e. no op types are excluded.
// Note: tokens are used verbatim; no whitespace trimming is performed.
// The parameter is taken by const reference to avoid copying the option string
// on every call (clang-tidy: performance-unnecessary-value-param).
std::set<std::string> GetExcludedNodeSet(const std::string& node_list_to_exclude) {
  std::set<std::string> set;
  if (!node_list_to_exclude.empty()) {
    std::stringstream node_list(node_list_to_exclude);
    std::string node;
    while (std::getline(node_list, node, ',')) {
      set.insert(node);
    }
  }
  return set;
}

std::vector<std::unique_ptr<ComputeCapability>>
TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
const IKernelLookup& /*kernel_lookup*/) const {
Expand Down Expand Up @@ -2462,10 +2486,27 @@
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);

std::vector<size_t> filtered_nodes_vector;
std::set<std::string> exclude_set = GetExcludedNodeSet(op_types_to_exclude_);

// Print excluded nodes, if any.
std::set<std::string>::iterator it;

Check warning on line 2492 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <set> for set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2492: Add #include <set> for set<> [build/include_what_you_use] [4]
for (it = exclude_set.begin(); it != exclude_set.end(); ++it) {
std::string op = *it;
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Exclude \"" << op << "\" from running on TRT, if any.";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Remove \"" << op << "\" from trt_op_types_to_exclude or specify trt_op_types_to_exclude with empty string to include the op in the input to TRT parser. However, it still depends on TRT parser to determine the eligibility of this op for TRT.";
}

SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
bool new_subgraph = true;

/* Iterate all the nodes and exclude the node if:
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
* 2. It's in the excluded set which is specified by trt_op_types_to_exclude.
*/
for (const auto& index : nodes_vector) {
const auto& node = graph.GetNode(node_index[index]);
bool supported_node = true;

/* If current node is control flow op, we take different approach based on following four cases:
*
Expand All @@ -2477,29 +2518,43 @@
* For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP.
*/
if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
auto sub_graphs = node->GetSubgraphs();
if (sub_graphs.size() != 0) {
bool all_subgraphs_are_supported = true;
for (auto sub_graph : sub_graphs) {
// TRT EP should consider the empty subgraph is fully supported by TRT.
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
continue;
}
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
all_subgraphs_are_supported = false;
break;
auto supported_control_flow_op = [&](const Node* node) {
auto sub_graphs = node->GetSubgraphs();
if (sub_graphs.size() != 0) {
for (auto sub_graph : sub_graphs) {
// TRT EP should consider the empty subgraph is fully supported by TRT.
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
continue;
}
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
// if not all its subgraphs are supported, we need to exclude this control flow op
return false;
}
}
}
if (!all_subgraphs_are_supported) {
// if not all its subgraphs are supported, we need to exclude this control flow op
continue;
}
return true;
};
supported_node = supported_control_flow_op(node);
}

// Exclude any ops, if applicable
if (exclude_set.find(node->OpType()) != exclude_set.end()) {
supported_node = false;
}

if (supported_node) {
if (new_subgraph) {
parser_nodes_vector.emplace_back();
// Mark all new graphs as "UnKnown" which will later be parsed by TRT parser
parser_nodes_vector.back().second = false;
new_subgraph = false;
}
parser_nodes_vector.back().first.emplace_back(index);
} else {
new_subgraph = true;
}
filtered_nodes_vector.push_back(index);
}

SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}};
bool early_termination = false;
supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
if (early_termination) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX";
static const std::string kOpTypesToExclude = "ORT_TENSORRT_OP_TYPES_TO_EXCLUDE";

Check warning on line 60 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 For a static/global string constant, use a C style string instead: "static const char kOpTypesToExclude[]". [runtime/string] [4] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h:60: For a static/global string constant, use a C style string instead: "static const char kOpTypesToExclude[]". [runtime/string] [4]
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
Expand Down Expand Up @@ -329,6 +330,10 @@
bool cuda_graph_enable_ = false;
std::string cache_prefix_;
bool engine_hw_compatible_ = false;
std::string op_types_to_exclude_;

// The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH
int32_t trt_version_;

// The OrtAllocator object will be get during ep compute time
// and should be kept for the lifetime of TRT EP object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";
constexpr const char* kONNXBytestream = "trt_onnx_bytestream";
constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size";
constexpr const char* kOpTypesToExclude = "trt_op_types_to_exclude";

} // namespace provider_option_names
} // namespace tensorrt
Expand Down Expand Up @@ -134,6 +135,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size)
.AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude)
.Parse(options)); // add new provider option here.

info.user_compute_stream = user_compute_stream;
Expand Down Expand Up @@ -188,6 +190,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)},
{tensorrt::provider_option_names::kOpTypesToExclude, MakeStringWithClassicLocale(info.op_types_to_exclude)},
};
return options;
}
Expand All @@ -206,6 +209,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes);
const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path);
const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path);
const std::string kOpTypesToExclude_ = empty_if_null(info.trt_op_types_to_exclude);

const ProviderOptions options{
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
Expand Down Expand Up @@ -251,6 +255,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.trt_onnx_bytestream))},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)},
{tensorrt::provider_option_names::kOpTypesToExclude, kOpTypesToExclude_},
};
return options;
}
Expand Down Expand Up @@ -355,5 +360,6 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible;
trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream;
trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size;
trt_provider_options_v2.trt_op_types_to_exclude = copy_string_if_needed(internal_options.op_types_to_exclude);
}
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
int ep_context_embed_mode{0};
std::string engine_cache_prefix{""};
bool engine_hw_compatible{false};
// There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7.
// TRT EP excludes DDS ops from running on TRT by default, user can override default value of trt_op_types_to_exclude with empty string to include all ops.
std::string op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};

Check warning on line 65 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h:65: Add #include <string> for string [build/include_what_you_use] [4]

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ struct Tensorrt_Provider : Provider {
info.engine_hw_compatible = options.trt_engine_hw_compatible != 0;
info.onnx_bytestream = options.trt_onnx_bytestream;
info.onnx_bytestream_size = options.trt_onnx_bytestream_size;
info.op_types_to_exclude = options.trt_op_types_to_exclude == nullptr ? "" : options.trt_op_types_to_exclude;

return std::make_shared<TensorrtProviderFactory>(info);
}
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2293,8 +2293,11 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions,
#ifdef USE_TENSORRT
onnxruntime::ProviderOptions provider_options_map;
for (size_t i = 0; i != num_keys; ++i) {
if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' ||
provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') {
// Keys must be non-empty; values must be non-empty as well, except for trt_op_types_to_exclude, whose value may be empty.
if (provider_options_keys[i] == nullptr ||
provider_options_keys[i][0] == '\0' ||
(provider_options_values[i] == nullptr && strcmp("trt_op_types_to_exclude", provider_options_keys[i])) ||
(provider_options_values[i][0] == '\0' && strcmp("trt_op_types_to_exclude", provider_options_keys[i]))) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty");
}

Expand Down Expand Up @@ -2409,6 +2412,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor
delete[] ptr->trt_profile_opt_shapes;
delete[] ptr->trt_ep_context_file_path;
delete[] ptr->trt_onnx_model_folder_path;
if (!ptr->trt_op_types_to_exclude) delete[] ptr->trt_op_types_to_exclude;
}

std::unique_ptr<OrtTensorRTProviderOptionsV2> p(ptr);
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
// and TRT EP instance, so it won't be released.)
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
onnx_model_folder_path;
onnx_model_folder_path, trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptionsV2 params;
Expand Down Expand Up @@ -824,6 +824,9 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_op_types_to_exclude") {
trt_op_types_to_exclude = option.second;
params.trt_op_types_to_exclude = trt_op_types_to_exclude.c_str();
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
Expand Down
60 changes: 60 additions & 0 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,66 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
RunSession(session_object9, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
}

TEST(TensorrtExecutionProviderTest, ExcludeOpsTest) {
  // mnist.onnx contains (among others) Conv, Add, MaxPool, Reshape and MatMul
  // nodes. Excluding "MaxPool" via trt_op_types_to_exclude forces the graph to
  // be split around every MaxPool node, so TRT is expected to end up with
  // three partitions (one engine cache file each) while the excluded MaxPool
  // nodes fall back to other EPs.
  PathString model_name = ORT_TSTR("testdata/mnist.onnx");
  SessionOptions so;
  so.session_logid = "TensorrtExecutionProviderExcludeOpsTest";
  RunOptions run_options;
  run_options.run_tag = so.session_logid;
  InferenceSession session_object{so, GetEnvironment()};

  // Build the input feed: a single 1x1x28x28 tensor filled with ones.
  auto cuda_provider = DefaultCudaExecutionProvider();
  auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1];
  std::vector<int64_t> input_shape = {1, 1, 28, 28};
  std::vector<float> input_values(784, 1.0f);  // 784 = 1*1*28*28
  OrtValue input_value;
  CreateMLValue<float>(cpu_allocator, input_shape, input_values, &input_value);
  NameMLValMap feeds;
  feeds.insert(std::make_pair("Input3", input_value));

  // Prepare outputs.
  std::vector<std::string> output_names;
  output_names.push_back("Plus214_Output_0");
  std::vector<OrtValue> fetches;

  // Start from a clean cache directory so the engine-file count below is reliable.
  RemoveCachesByType("./", ".engine");
  OrtTensorRTProviderOptionsV2 params;
  params.trt_engine_cache_enable = 1;
  params.trt_op_types_to_exclude = "MaxPool";
  std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
  EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
  ASSERT_TRUE(session_object.Load(model_name).IsOK());
  ASSERT_TRUE(session_object.Initialize().IsOK());
  ASSERT_TRUE(session_object.Run(run_options, feeds, output_names, &fetches).IsOK());

  // The whole graph should be partitioned into 3 TRT subgraphs (one cached
  // engine per subgraph) with the excluded MaxPool nodes left to other EPs.
  std::vector<fs::path> engine_files = GetCachesByType("./", ".engine");
  ASSERT_EQ(engine_files.size(), 3);
}

TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
PathString model_name = ORT_TSTR("testdata/trt_plugin_custom_op_test.onnx");
SessionOptions so;
Expand Down
Loading