Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT] Support Multiple EP Context #23294

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
00cecb7
Support EP context partition
jingyanwangms Dec 21, 2024
bf7c0df
Unit test and perftest fix
jingyanwangms Jan 7, 2025
e8880c5
perf test and EP updates
jingyanwangms Jan 8, 2025
48c882e
Updated session option merge and unit test
jingyanwangms Jan 8, 2025
842b3e1
Clean up
jingyanwangms Jan 9, 2025
a63973e
Remove debug logs
jingyanwangms Jan 9, 2025
0ac2ac7
Set correct ep_cache_context
jingyanwangms Jan 15, 2025
cfaee44
Fix orphaned output
jingyanwangms Jan 20, 2025
41958dc
Fix GetCapabilities logic
jingyanwangms Jan 21, 2025
44818f1
Add tensor(bool) to valid EPContext input type (Zcode model error)
jingyanwangms Jan 22, 2025
1fa99bd
merge with main
jingyanwangms Jan 24, 2025
90c2554
Skip memory test
jingyanwangms Jan 24, 2025
194b90c
Add unit test
jingyanwangms Jan 28, 2025
2786f88
Try memory test
jingyanwangms Jan 28, 2025
91e4e5e
lint runner
jingyanwangms Jan 29, 2025
2d8b2ba
Merge branch 'main' into jingywa/epcontext
jingyanwangms Jan 29, 2025
3a56834
Remove old UT carried from old branch
jingyanwangms Jan 29, 2025
afe0a26
Fix windows build error and regenerate ContribOperators.md
jingyanwangms Jan 31, 2025
169e13d
Update onnxruntime/core/providers/tensorrt/tensorrt_execution_provide…
jingyanwangms Jan 31, 2025
3db4f82
Fix windows build
jingyanwangms Jan 31, 2025
1b37a57
Merge
jingyanwangms Feb 1, 2025
b8464fa
Fix TensorrtExecutionProviderTest.SessionCreationWithMultiThreadsAndI…
jingyanwangms Feb 1, 2025
b18f425
Fix CI
jingyanwangms Feb 3, 2025
7ef865c
Add tests
jingyanwangms Feb 4, 2025
253a599
Update UT
jingyanwangms Feb 5, 2025
1fd22ef
fix typo
jingyanwangms Feb 5, 2025
bb393a9
Fix UT
jingyanwangms Feb 5, 2025
8fc043f
Fix nested test
jingyanwangms Feb 5, 2025
24cde8c
lint
jingyanwangms Feb 5, 2025
61df552
Disable SessionCreationWithMultiThreadsAndInferenceWithMultiThreads t…
jingyanwangms Feb 6, 2025
02015d4
Overwrite _ctx model when embedded 1 ep context node with dynamic inp…
jingyanwangms Feb 15, 2025
ef119d9
Add EPContextEmbeddedDynamicShouldRegenerate
jingyanwangms Feb 17, 2025
a00f236
embedded dynamic input change rewrite model
jingyanwangms Feb 17, 2025
a15f663
Lint
jingyanwangms Feb 18, 2025
cb29978
Merge branch 'main' into jingywa/epcontext
jingyanwangms Feb 18, 2025
a998d98
Revert "Fix ACL option parsing (#23586)"
jingyanwangms Feb 18, 2025
f09f319
Revert "Fix attention fusion in conformer encoder (#23711)"
jingyanwangms Feb 18, 2025
4396aa9
Log UT error
jingyanwangms Feb 18, 2025
1f16d76
Add logging for UT
jingyanwangms Feb 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dt><tt>T</tt> : tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bool), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types.</dd>
</dl>

Expand Down Expand Up @@ -1754,7 +1754,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dt><tt>T</tt> : tensor(float), tensor(double), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float or half tensors.</dd>
</dl>

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3365,6 +3365,7 @@ void RegisterContribSchemas() {
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(bool)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
Expand Down
93 changes: 34 additions & 59 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
return false;
}

int FindCtxNodeInGraph(const GraphViewer& graph_viewer) {
  // Locate the EP context node in this subgraph (graph_viewer).
  // Assumes the subgraph holds at most one such node; returns its node
  // index, or -1 when no EPContext node is present.
  const int max_index = graph_viewer.MaxNodeIndex();
  for (int idx = 0; idx < max_index; ++idx) {
    const auto* candidate = graph_viewer.GetNode(idx);
    if (candidate == nullptr) {
      continue;  // index slot not in use (e.g. node was removed)
    }
    if (candidate->OpType() == EPCONTEXT_OP) {
      return idx;
    }
  }
  return -1;
}

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
Expand All @@ -40,38 +52,18 @@ const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
return main_graph.ModelPath();
}

/*
 * Update the ep_cache_context attribute of the EP context node with the given
 * engine binary data.
 *
 * Assumes the model's graph has the EP context node at node index 0 (the model
 * produced by CreateCtxModel contains exactly one node).
 * When size == 0 the attribute is cleared to an empty string.
 */
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
                                     char* engine_data,
                                     size_t size) {
  ONNX_NAMESPACE::GraphProto* graph_proto = model_proto->mutable_graph();
  ONNX_NAMESPACE::NodeProto* node_proto = graph_proto->mutable_node(0);

  for (int i = 0; i < node_proto->attribute_size(); ++i) {
    ONNX_NAMESPACE::AttributeProto* attribute_proto = node_proto->mutable_attribute(i);
    if (attribute_proto->name() == EP_CACHE_CONTEXT) {
      std::string engine_data_str;
      if (size > 0) {
        engine_data_str.assign(engine_data, size);
      }
      // Move instead of copy: the engine blob can be large.
      attribute_proto->set_s(std::move(engine_data_str));
      // Attribute names are unique within a node; stop scanning once found.
      break;
    }
  }
}

/*
* Create "EP context node" model where engine information is embedded
*/
ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger) {
std::unique_ptr<Model> CreateCtxModel(const GraphViewer& graph_viewer,
const std::string fused_subgraph_name,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger) {
auto model_build = graph_viewer.CreateModel(*logger);
auto& graph_build = model_build->MainGraph();

Expand Down Expand Up @@ -123,18 +115,11 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3);

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
ORT_ENFORCE(graph_build.Resolve().IsOK());

// Serialize modelproto to string
auto new_graph_viewer = graph_build.CreateGraphViewer();
auto& metadata = graph_viewer.GetGraph().GetModel().MetaData();
auto model = new_graph_viewer->CreateModel(*logger, metadata);
auto model_proto = model->ToProto();
new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

return model_proto.release();
graph_build.AddNode(fused_subgraph_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
auto status = graph_build.Resolve();
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());

return model_build;
}

/*
Expand Down Expand Up @@ -203,17 +188,6 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path,
return ctx_model_path;
}

/*
 * Dump the "EP context" model to disk at ctx_model_path.
 *
 * Best-effort: open/serialization failures are logged as warnings rather than
 * thrown, since callers treat the dump as an optional cache artifact. The
 * original code silently ignored both failure modes.
 */
void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
                  const std::string& ctx_model_path) {
  std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary);
  if (!dump) {
    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to open " + ctx_model_path + " for writing";
    return;
  }
  if (!model_proto->SerializeToOstream(dump)) {
    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to serialize EP context model to " + ctx_model_path;
    return;
  }
  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
}

bool IsAbsolutePath(const std::string& path_string) {
#ifdef _WIN32
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
Expand Down Expand Up @@ -267,11 +241,11 @@ bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
return engine_cache_path.stem().extension().string() == ".stripped";
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx) {
if (!ValidateEPCtxNode(graph_viewer, ctx_node_idx)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
}
auto node = graph_viewer.GetNode(0);
auto node = graph_viewer.GetNode(ctx_node_idx);
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
Expand Down Expand Up @@ -381,14 +355,14 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
/*
* The sanity check for EP context contrib op.
*/
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(0);
assert(graph_viewer.GetNode(ctx_node_idx)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(ctx_node_idx);
auto& attrs = node->GetAttributes();

// Show the warning if compute capability is not matched
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
if (attrs.find(COMPUTE_CAPABILITY) != attrs.end() && attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
// Verify if engine was compiled with ampere+ hardware compatibility enabled
if (model_compute_capability == "80+") {
Expand All @@ -415,4 +389,5 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe

return true;
}

} // namespace onnxruntime
28 changes: 13 additions & 15 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,23 @@ static const std::string EPCONTEXT_WARNING =
for the best model loading time";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
int FindCtxNodeInGraph(const GraphViewer& graph_viewer);

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer);
std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);
ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger);
std::unique_ptr<Model> CreateCtxModel(const GraphViewer& graph_viewer,
const std::string fused_subgraph_name,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
const std::string compute_capability,
const std::string onnx_model_path,
const logging::Logger* logger);
std::string GetCtxModelPath(const std::string& ep_context_file_path,
const std::string& original_model_path);
bool IsAbsolutePath(const std::string& path_string);
bool IsRelativePathToParentPath(const std::string& path_string);
void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string& ctx_model_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size);

class TensorRTCacheModelHandler {
public:
Expand All @@ -67,9 +65,9 @@ class TensorRTCacheModelHandler {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

bool ValidateEPCtxNode(const GraphViewer& graph_viewer);
bool ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx);

Status GetEpContextFromGraph(const GraphViewer& graph_viewer);
Status GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx);

private:
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
Expand Down
Loading
Loading