Skip to content

Commit

Permalink
[TensorRT EP] Revert "Add new provider option to exclude nodes from r…
Browse files Browse the repository at this point in the history
…unning on TRT" (microsoft#22878)

- Revert microsoft#22681
- But still implicitly exclude DDS ops for TRT 10. Will later provide
better PR to add trt_op_types_to_exclude provider option.
  • Loading branch information
chilo-ms authored and ankitm3k committed Dec 11, 2024
1 parent aa221ee commit 9c89851
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,21 @@ struct OrtTensorRTProviderOptionsV2 {
* directory by means of the "trt_onnx_model_folder_path" option.
*
*/
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false,
// nonzero = true
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
// the ONNX model containing the weights (applicable only when
// the "trt_weight_stripped_engine_enable" option is enabled)
const void* trt_onnx_bytestream{nullptr}; // The byte stream of th original ONNX model containing the weights
// (applicable only when the "trt_weight_stripped_engine_enable"
// option is enabled)
// can be updated using: UpdateTensorRTProviderOptionsWithValue
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
// can be updated using: UpdateTensorRTProviderOptionsWithValue
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
const char* trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"}; // Exclude specific ops from running on TRT.
// There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7.
// TRT EP excludes DDS ops from running on TRT by default, user can override default value with empty string to include all ops.
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false,
// nonzero = true
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
// the ONNX model containing the weights (applicable only when
// the "trt_weight_stripped_engine_enable" option is enabled)
const void* trt_onnx_bytestream{nullptr}; // The byte stream of th original ONNX model containing the weights
// (applicable only when the "trt_weight_stripped_engine_enable"
// option is enabled)
// can be updated using: UpdateTensorRTProviderOptionsWithValue
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
// can be updated using: UpdateTensorRTProviderOptionsWithValue

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
};
43 changes: 13 additions & 30 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1379,8 +1379,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
engine_hw_compatible_ = info.engine_hw_compatible;
op_types_to_exclude_ = info.op_types_to_exclude;

} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
Expand Down Expand Up @@ -1567,11 +1565,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}

const std::string op_types_to_exclude_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOpTypesToExclude);
if (!op_types_to_exclude_env.empty()) {
op_types_to_exclude_ = op_types_to_exclude_env;
}

} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
Expand Down Expand Up @@ -1773,8 +1766,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_
<< ", trt_op_types_to_exclude: " << op_types_to_exclude_;
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
Expand Down Expand Up @@ -2442,18 +2434,6 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t&
return cycle_detected;
}

std::set<std::string> GetExcludedNodeSet(std::string node_list_to_exclude) {
std::set<std::string> set;
if (!node_list_to_exclude.empty()) {
std::stringstream node_list(node_list_to_exclude);
std::string node;
while (std::getline(node_list, node, ',')) {
set.insert(node);
}
}
return set;
}

std::vector<std::unique_ptr<ComputeCapability>>
TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
const IKernelLookup& /*kernel_lookup*/) const {
Expand Down Expand Up @@ -2486,14 +2466,17 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
std::vector<size_t> nodes_vector(number_of_ort_nodes);
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);

std::set<std::string> exclude_set = GetExcludedNodeSet(op_types_to_exclude_);
std::set<std::string> exclude_ops_set;

// Print excluded nodes, if any.
std::set<std::string>::iterator it;
for (it = exclude_set.begin(); it != exclude_set.end(); ++it) {
std::string op = *it;
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Exclude \"" << op << "\" from running on TRT, if any.";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Remove \"" << op << "\" from trt_op_types_to_exclude or specify trt_op_types_to_exclude with empty string to include the op in the input to TRT parser. However, it still depends on TRT parser to determine the eligibility of this op for TRT.";
/*
* There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10.
* TRT EP automatically excludes DDS ops from running on TRT.
*/
if (trt_version_ >= 100000 && trt_version_ < 110000) {
exclude_ops_set.insert("NonMaxSuppression");
exclude_ops_set.insert("NonZero");
exclude_ops_set.insert("RoiAlign");
LOGS_DEFAULT(VERBOSE) << "There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable";
}

SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
Expand All @@ -2502,7 +2485,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,

/* Iterate all the nodes and exclude the node if:
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
* 2. It's in the exlucded set which specified by trt_op_types_to_exclude.
* 2. It's a DDS op.
*/
for (const auto& index : nodes_vector) {
const auto& node = graph.GetNode(node_index[index]);
Expand Down Expand Up @@ -2538,7 +2521,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
}

// Exclude any ops, if applicable
if (exclude_set.find(node->OpType()) != exclude_set.end()) {
if (exclude_ops_set.find(node->OpType()) != exclude_ops_set.end()) {
supported_node = false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL";
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX";
static const std::string kOpTypesToExclude = "ORT_TENSORRT_OP_TYPES_TO_EXCLUDE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";
constexpr const char* kONNXBytestream = "trt_onnx_bytestream";
constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size";
constexpr const char* kOpTypesToExclude = "trt_op_types_to_exclude";

} // namespace provider_option_names
} // namespace tensorrt
Expand Down Expand Up @@ -135,7 +134,6 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size)
.AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude)
.Parse(options)); // add new provider option here.

info.user_compute_stream = user_compute_stream;
Expand Down Expand Up @@ -190,7 +188,6 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)},
{tensorrt::provider_option_names::kOpTypesToExclude, MakeStringWithClassicLocale(info.op_types_to_exclude)},
};
return options;
}
Expand All @@ -209,7 +206,6 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes);
const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path);
const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path);
const std::string kOpTypesToExclude_ = empty_if_null(info.trt_op_types_to_exclude);

const ProviderOptions options{
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
Expand Down Expand Up @@ -255,7 +251,6 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)},
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.trt_onnx_bytestream))},
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)},
{tensorrt::provider_option_names::kOpTypesToExclude, kOpTypesToExclude_},
};
return options;
}
Expand Down Expand Up @@ -360,6 +355,5 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible;
trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream;
trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size;
trt_provider_options_v2.trt_op_types_to_exclude = copy_string_if_needed(internal_options.op_types_to_exclude);
}
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ struct TensorrtExecutionProviderInfo {
int ep_context_embed_mode{0};
std::string engine_cache_prefix{""};
bool engine_hw_compatible{false};
// There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7.
// TRT EP excludes DDS ops from running on TRT by default, user can override default value of trt_op_types_to_exclude with empty string to include all ops.
std::string op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ struct Tensorrt_Provider : Provider {
info.engine_hw_compatible = options.trt_engine_hw_compatible != 0;
info.onnx_bytestream = options.trt_onnx_bytestream;
info.onnx_bytestream_size = options.trt_onnx_bytestream_size;
info.op_types_to_exclude = options.trt_op_types_to_exclude == nullptr ? "" : options.trt_op_types_to_exclude;

return std::make_shared<TensorrtProviderFactory>(info);
}
Expand Down
8 changes: 2 additions & 6 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2294,11 +2294,8 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions,
#ifdef USE_TENSORRT
onnxruntime::ProviderOptions provider_options_map;
for (size_t i = 0; i != num_keys; ++i) {
// Don't allow key and value to be empty except the value of trt_op_types_to_exclude
if (provider_options_keys[i] == nullptr ||
provider_options_keys[i][0] == '\0' ||
(provider_options_values[i] == nullptr && strcmp("trt_op_types_to_exclude", provider_options_keys[i])) ||
(provider_options_values[i][0] == '\0' && strcmp("trt_op_types_to_exclude", provider_options_keys[i]))) {
if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' ||
provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty");
}

Expand Down Expand Up @@ -2413,7 +2410,6 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor
delete[] ptr->trt_profile_opt_shapes;
delete[] ptr->trt_ep_context_file_path;
delete[] ptr->trt_onnx_model_folder_path;
if (!ptr->trt_op_types_to_exclude) delete[] ptr->trt_op_types_to_exclude;
}

std::unique_ptr<OrtTensorRTProviderOptionsV2> p(ptr);
Expand Down
5 changes: 1 addition & 4 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
// and TRT EP instance, so it won't be released.)
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
onnx_model_folder_path, trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};
onnx_model_folder_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtTensorRTProviderOptionsV2 params;
Expand Down Expand Up @@ -824,9 +824,6 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_op_types_to_exclude") {
trt_op_types_to_exclude = option.second;
params.trt_op_types_to_exclude = trt_op_types_to_exclude.c_str();
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
Expand Down
60 changes: 0 additions & 60 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,66 +612,6 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
RunSession(session_object9, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
}

TEST(TensorrtExecutionProviderTest, ExcludeOpsTest) {
/* The mnist.onnx looks like this:
* Conv
* |
* Add
* .
* .
* |
* MaxPool
* |
* .
* .
* MaxPool
* |
* Reshape
* |
* MatMul
* .
* .
*
*/
PathString model_name = ORT_TSTR("testdata/mnist.onnx");
SessionOptions so;
so.session_logid = "TensorrtExecutionProviderExcludeOpsTest";
RunOptions run_options;
run_options.run_tag = so.session_logid;
InferenceSession session_object{so, GetEnvironment()};
auto cuda_provider = DefaultCudaExecutionProvider();
auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1];
std::vector<int64_t> dims_op_x = {1, 1, 28, 28};
std::vector<float> values_op_x(784, 1.0f); // 784=1*1*28*28
OrtValue ml_value_x;
CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_x);
NameMLValMap feeds;
feeds.insert(std::make_pair("Input3", ml_value_x));

// prepare outputs
std::vector<std::string> output_names;
output_names.push_back("Plus214_Output_0");
std::vector<OrtValue> fetches;

RemoveCachesByType("./", ".engine");
OrtTensorRTProviderOptionsV2 params;
params.trt_engine_cache_enable = 1;
params.trt_op_types_to_exclude = "MaxPool";
std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
auto status = session_object.Load(model_name);
ASSERT_TRUE(status.IsOK());
status = session_object.Initialize();
ASSERT_TRUE(status.IsOK());
status = session_object.Run(run_options, feeds, output_names, &fetches);
ASSERT_TRUE(status.IsOK());

std::vector<fs::path> engine_files;
engine_files = GetCachesByType("./", ".engine");
// The whole graph should be partitioned into 3 TRT subgraphs and 2 cpu nodes
ASSERT_EQ(engine_files.size(), 3);
}

TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
PathString model_name = ORT_TSTR("testdata/trt_plugin_custom_op_test.onnx");
SessionOptions so;
Expand Down

0 comments on commit 9c89851

Please sign in to comment.