Overwrite _ctx model when embedding one EP context node with dynamic input and change graph_view access
jingyanwangms committed Feb 15, 2025
1 parent 61df552 commit 02015d4
Showing 4 changed files with 110 additions and 74 deletions.
60 changes: 39 additions & 21 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -20,27 +20,15 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
* Note: Please see more details about "EPContext" contrib op in contrib_defs.cc
*/
bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
auto node = graph_viewer.GetNode(i);
for (auto node_index : graph_viewer.GetNodesInTopologicalOrder()) {
auto node = graph_viewer.GetNode(node_index);
if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
return true;
}
}
return false;
}

int FindCtxNodeInGraph(const GraphViewer& graph_viewer) {
// Assumes there's only 1 context node in this subgraph (graph_viewer)
// Returns index of node
for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
auto node = graph_viewer.GetNode(i);
if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
return i;
}
}
return -1;
}

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
@@ -52,6 +40,27 @@ const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
return main_graph.ModelPath();
}

/*
* Update ep_cache_context attribute of the EP context node with the given engine binary data
*/
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size) {
ONNX_NAMESPACE::GraphProto* graph_proto = model_proto->mutable_graph();
ONNX_NAMESPACE::NodeProto* node_proto = graph_proto->mutable_node(0);

for (int i = 0; i < node_proto->attribute_size(); ++i) {
ONNX_NAMESPACE::AttributeProto* attribute_proto = node_proto->mutable_attribute(i);
if (attribute_proto->name() == EP_CACHE_CONTEXT) {
std::string engine_data_str = "";
if (size > 0) {
engine_data_str.assign(engine_data, size);
}
attribute_proto->set_s(engine_data_str);
}
}
}

/*
* Create "EP context node" model where engine information is embedded
*/
@@ -188,6 +197,13 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path,
return ctx_model_path;
}

void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path) {
std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
}
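Taken together, the two helpers above form the overwrite path: refresh the ep_cache_context bytes, then rewrite the _ctx model on disk. A minimal usage sketch (RefreshCtxModel is a hypothetical wrapper; engine, proto, and path stand in for the provider's state), mirroring the call site added in tensorrt_execution_provider.cc below:

// Sketch: overwrite a single-node EP context model after an engine rebuild.
// Assumes TensorRT 8+, where IHostMemory can be owned by std::unique_ptr.
void RefreshCtxModel(nvinfer1::ICudaEngine& engine,
                     ONNX_NAMESPACE::ModelProto* proto,
                     const std::string& path) {
  std::unique_ptr<nvinfer1::IHostMemory> blob{engine.serialize()};
  UpdateCtxNodeModelEngineContext(proto,
                                  reinterpret_cast<char*>(blob->data()),
                                  blob->size());
  DumpCtxModel(proto, path);  // truncates and rewrites the file at `path`
}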

bool IsAbsolutePath(const std::string& path_string) {
#ifdef _WIN32
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
@@ -241,11 +257,12 @@ bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
return engine_cache_path.stem().extension().string() == ".stripped";
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx) {
if (!ValidateEPCtxNode(graph_viewer, ctx_node_idx)) {
Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
}
auto node = graph_viewer.GetNode(ctx_node_idx);
// graph_viewer.GetNodesInTopologicalOrder().size() == 1 is validated in ValidateEPCtxNode
auto node = graph_viewer.GetNode(graph_viewer.GetNodesInTopologicalOrder()[0]);
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
@@ -355,10 +372,11 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
/*
* The sanity check for EP context contrib op.
*/
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(ctx_node_idx)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(ctx_node_idx);
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
const auto& subgraph_node_list = graph_viewer.GetNodesInTopologicalOrder();
assert(subgraph_node_list.size() == 1); // There should only be 1 node in filtered graph
const auto node = graph_viewer.GetNode(subgraph_node_list[0]);
assert(node->OpType() == EPCONTEXT_OP);
auto& attrs = node->GetAttributes();

// Show the warning if compute capability is not matched
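The access pattern this file now standardizes on (fetch the lone node via GetNodesInTopologicalOrder() instead of probing every index up to MaxNodeIndex()) reduces to a short sketch. GetSingleCtxNode is a hypothetical helper, assuming the viewer has already been filtered down to the EP context subgraph:

// Sketch: single-node lookup via topological order.
const Node* GetSingleCtxNode(const GraphViewer& graph_viewer) {
  const auto& order = graph_viewer.GetNodesInTopologicalOrder();
  assert(order.size() == 1);  // the fused EP context subgraph holds exactly one node
  return graph_viewer.GetNode(order[0]);
}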
10 changes: 7 additions & 3 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -41,7 +41,11 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path,
const std::string& original_model_path);
bool IsAbsolutePath(const std::string& path_string);
bool IsRelativePathToParentPath(const std::string& path_string);

void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size);
class TensorRTCacheModelHandler {
public:
TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
@@ -65,9 +69,9 @@ class TensorRTCacheModelHandler {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

bool ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx);
bool ValidateEPCtxNode(const GraphViewer& graph_viewer);

Status GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx);
Status GetEpContextFromGraph(const GraphViewer& graph_viewer);

private:
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
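With ctx_node_idx gone from both declarations, a call site only hands over the graph viewer. A sketch of the expected usage (trt_cache_model_handler is assumed to be an already-constructed TensorRTCacheModelHandler; this matches the Compile() change below):

// Sketch: the handler now locates the EP context node by itself.
if (GraphHasCtxNode(graph_body_viewer)) {
  ORT_RETURN_IF_ERROR(trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer));
}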
110 changes: 62 additions & 48 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1732,43 +1732,43 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] CUDA version is " << cuda_version_;

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
<< "device_id: " << device_id_
<< ", trt_max_partition_iterations: " << max_partition_iterations_
<< ", trt_min_subgraph_size: " << min_subgraph_size_
<< ", trt_max_workspace_size: " << max_workspace_size_
<< ", trt_fp16_enable: " << fp16_enable_
<< ", trt_int8_enable: " << int8_enable_
<< ", trt_int8_calibration_cache_name: " << int8_calibration_cache_name_
<< ", int8_calibration_cache_available: " << int8_calibration_cache_available_
<< ", trt_int8_use_native_tensorrt_calibration_table: " << int8_use_native_tensorrt_calibration_table_
<< ", trt_dla_enable: " << dla_enable_
<< ", trt_dla_core: " << dla_core_
<< ", trt_dump_subgraphs: " << dump_subgraphs_
<< ", trt_engine_cache_enable: " << engine_cache_enable_
<< ", trt_weight_stripped_engine_enable: " << weight_stripped_engine_enable_
<< ", trt_onnx_model_folder_path: " << onnx_model_folder_path_
<< ", trt_cache_path: " << cache_path_
<< ", trt_global_cache_path: " << global_cache_path_
<< ", trt_engine_decryption_enable: " << engine_decryption_enable_
<< ", trt_engine_decryption_lib_path: " << engine_decryption_lib_path_
<< ", trt_force_sequential_engine_build: " << force_sequential_engine_build_
<< ", trt_context_memory_sharing_enable: " << context_memory_sharing_enable_
<< ", trt_layer_norm_fp32_fallback: " << layer_norm_fp32_fallback_
<< ", trt_build_heuristics_enable: " << build_heuristics_enable_
<< ", trt_sparsity_enable: " << sparsity_enable_
<< ", trt_builder_optimization_level: " << builder_optimization_level_
<< ", trt_auxiliary_streams: " << auxiliary_streams_
<< ", trt_tactic_sources: " << tactic_sources_
<< ", trt_profile_min_shapes: " << profile_min_shapes
<< ", trt_profile_max_shapes: " << profile_max_shapes
<< ", trt_profile_opt_shapes: " << profile_opt_shapes
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_
<< ", trt_dump_ep_context_model: " << dump_ep_context_model_
<< ", trt_ep_context_file_path: " << ep_context_file_path_
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
<< "device_id: " << device_id_
<< ", trt_max_partition_iterations: " << max_partition_iterations_
<< ", trt_min_subgraph_size: " << min_subgraph_size_
<< ", trt_max_workspace_size: " << max_workspace_size_
<< ", trt_fp16_enable: " << fp16_enable_
<< ", trt_int8_enable: " << int8_enable_
<< ", trt_int8_calibration_cache_name: " << int8_calibration_cache_name_
<< ", int8_calibration_cache_available: " << int8_calibration_cache_available_
<< ", trt_int8_use_native_tensorrt_calibration_table: " << int8_use_native_tensorrt_calibration_table_
<< ", trt_dla_enable: " << dla_enable_
<< ", trt_dla_core: " << dla_core_
<< ", trt_dump_subgraphs: " << dump_subgraphs_
<< ", trt_engine_cache_enable: " << engine_cache_enable_
<< ", trt_weight_stripped_engine_enable: " << weight_stripped_engine_enable_
<< ", trt_onnx_model_folder_path: " << onnx_model_folder_path_
<< ", trt_cache_path: " << cache_path_
<< ", trt_global_cache_path: " << global_cache_path_
<< ", trt_engine_decryption_enable: " << engine_decryption_enable_
<< ", trt_engine_decryption_lib_path: " << engine_decryption_lib_path_
<< ", trt_force_sequential_engine_build: " << force_sequential_engine_build_
<< ", trt_context_memory_sharing_enable: " << context_memory_sharing_enable_
<< ", trt_layer_norm_fp32_fallback: " << layer_norm_fp32_fallback_
<< ", trt_build_heuristics_enable: " << build_heuristics_enable_
<< ", trt_sparsity_enable: " << sparsity_enable_
<< ", trt_builder_optimization_level: " << builder_optimization_level_
<< ", trt_auxiliary_streams: " << auxiliary_streams_
<< ", trt_tactic_sources: " << tactic_sources_
<< ", trt_profile_min_shapes: " << profile_min_shapes
<< ", trt_profile_max_shapes: " << profile_max_shapes
<< ", trt_profile_opt_shapes: " << profile_opt_shapes
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_
<< ", trt_dump_ep_context_model: " << dump_ep_context_model_
<< ", trt_ep_context_file_path: " << ep_context_file_path_
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -2811,11 +2811,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}

Status status;
int ctx_node_idx = FindCtxNodeInGraph(graph_body_viewer);
if (ctx_node_idx >= 0) {
if (GraphHasCtxNode(graph_body_viewer)) {
status = CreateNodeComputeInfoFromPrecompiledEngine(graph_body_viewer,
fused_node,
ctx_node_idx,
input_map,
output_map,
node_compute_funcs);
@@ -2910,6 +2908,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
bool has_explicit_profile = false;
bool apply_explicit_profile = false;
int num_profiles = 0;
is_single_node_epcontext_graph = graph_body_viewer.NumberOfNodes() == 1 && GraphHasCtxNode(graph_body_viewer);
std::vector<nvinfer1::IOptimizationProfile*> trt_profiles;

// Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape dimension(s) and min/max/opt values for dynamic shape input tensor.
@@ -3200,10 +3199,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
}

// Generate file name for dumping ep context model
if (dump_ep_context_model_ && ctx_model_path_.empty()) {
ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_);
}


if (!has_dynamic_shape) {
std::string timing_cache_path = "";
bool engine_update = false;
@@ -3313,7 +3309,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
// Serialize engine profile if it has explicit profiles
if (has_explicit_profile) {
SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
}

if (engine_decryption_enable_) {
@@ -3473,6 +3468,20 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
GetLogger());

trt_ep_context_models.emplace_back(std::move(trt_ep_context_model_ptr));
// if (ep_context_embed_mode_ == 0 && is_single_node_epcontext_graph) {

// // Serialize modelproto to string
// auto& graph_build = trt_ep_context_model_ptr->MainGraph();
// auto new_graph_viewer = graph_build.CreateGraphViewer();
// auto& metadata = graph_body_viewer.GetGraph().GetModel().MetaData();
// auto model = new_graph_viewer->CreateModel(*logger, metadata);
// auto model_proto = model->ToProto();
// new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
// model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

// std::cout << "Dumping EP context model to " << ctx_model_path_ << std::endl;
// DumpCtxModel(model_proto_.get(), ctx_model_path_);
// }
}

// Create function state
@@ -3843,7 +3852,13 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView

// Give a warning that ep context need to be regenerated
if (dump_ep_context_model_ && ep_context_embed_mode_) {
LOGS_DEFAULT(WARNING) << "Engine was updated during inference due to dynamic input changed change. Please regenerate EP context model.";
std::cout << "Updating model during inference due to dynamic input changed change." << std::endl;

Check warning on line 3855 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <iostream> for cout [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:3855: Add #include <iostream> for cout [build/include_what_you_use] [4]
if (is_single_node_epcontext_graph) {
UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size());
DumpCtxModel(model_proto_.get(), ctx_model_path_);
} else {
LOGS_DEFAULT(WARNING) << "Engine was updated during inference due to dynamic input change. Please regenerate EP context model.";
}
}
context_update = true;

@@ -4068,7 +4083,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView

Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
const Node& fused_node,
const int ctx_node_idx,
std::unordered_map<std::string, size_t>& input_map,
std::unordered_map<std::string, size_t>& output_map,
std::vector<NodeComputeInfo>& node_compute_funcs) {
@@ -4091,7 +4105,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
onnx_model_bytestream_,
onnx_model_bytestream_size_,
detailed_build_log_);
auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer, ctx_node_idx);
auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
}
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -349,7 +349,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
std::string ep_cache_context_attr_;
std::string engine_cache_relative_path_to_context_model_dir;
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();

bool is_single_node_epcontext_graph = false;

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;

@@ -569,7 +570,6 @@ class TensorrtExecutionProvider : public IExecutionProvider {
*/
Status CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
const Node& fused_node,
const int ctx_node_idx,
std::unordered_map<std::string, size_t>& input_map,
std::unordered_map<std::string, size_t>& output_map,
std::vector<NodeComputeInfo>& node_compute_funcs);
