
Commit
add new UT
HectorSVC committed Feb 15, 2025
1 parent c8cfc2e commit c8eb395
Showing 2 changed files with 78 additions and 27 deletions.
onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc (2 additions, 0 deletions)
@@ -470,8 +470,10 @@ Status QnnBackendManager::InitializeProfiling() {
  QnnProfile_Level_t qnn_profile_level = QNN_PROFILE_LEVEL_BASIC;
  if (ProfilingLevel::BASIC == profiling_level_merge_) {
    qnn_profile_level = QNN_PROFILE_LEVEL_BASIC;
+    LOGS_DEFAULT(VERBOSE) << "Profiling level set to basic.";
  } else if (ProfilingLevel::DETAILED == profiling_level_merge_) {
    qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED;
+    LOGS_DEFAULT(VERBOSE) << "Profiling level set to detailed.";
  }
  Qnn_ErrorHandle_t result = qnn_interface_.profileCreate(backend_handle_, qnn_profile_level, &profile_backend_handle_);
  ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to create QNN profile! Error: ", QnnErrorHandleToString(result));
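The verbose logs added above report which profiling level QNN EP applies. That level is driven by provider options; below is a minimal sketch, assuming the documented QNN EP options profiling_level and profiling_file_path and a hypothetical model path, of enabling detailed profiling from application code:

// Sketch only: request QNN profiling so that InitializeProfiling() above
// selects QNN_PROFILE_LEVEL_DETAILED. Option names follow the QNN EP documentation.
#include <string>
#include <unordered_map>
#include "onnxruntime_cxx_api.h"

void EnableQnnProfiling(Ort::Env& env) {
  std::unordered_map<std::string, std::string> provider_options;
#if defined(_WIN32)
  provider_options["backend_path"] = "QnnHtp.dll";
#else
  provider_options["backend_path"] = "libQnnHtp.so";
#endif
  provider_options["profiling_level"] = "detailed";             // or "basic"
  provider_options["profiling_file_path"] = "qnn_profile.csv";  // where profiling events are written
  Ort::SessionOptions so;
  so.AppendExecutionProvider("QNN", provider_options);
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);  // hypothetical model path
}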
onnxruntime/test/providers/qnn/qnn_ep_context_test.cc (76 additions, 27 deletions)
@@ -1053,34 +1053,28 @@ static void CreateQdqModel(const std::string& model_file_name, const Logger& log
static void DumpModelWithSharedCtx(const ProviderOptions& provider_options,
                                   const std::string& onnx_model_path1,
                                   const std::string& onnx_model_path2) {
-  SessionOptions so;
-  so.session_logid = "qnn_ctx_model_logger";
-  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"));
-  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"));
-  RunOptions run_options;
-  run_options.run_tag = so.session_logid;
+  Ort::SessionOptions so;
+  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0");
+  // enable ep.share_ep_contexts so that QNN EP shares the QnnBackendManager across sessions
+  so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");

-  auto qnn_ep = QnnExecutionProviderWithOptions(provider_options, &so);
-  std::shared_ptr<IExecutionProvider> qnn_ep_shared(std::move(qnn_ep));
+  so.AppendExecutionProvider("QNN", provider_options);

-  InferenceSessionWrapper session_object1{so, GetEnvironment()};
-  ASSERT_STATUS_OK(session_object1.RegisterExecutionProvider(qnn_ep_shared));
-  ASSERT_STATUS_OK(session_object1.Load(ToPathString(onnx_model_path1)));
-  ASSERT_STATUS_OK(session_object1.Initialize());
+  // Create 2 sessions to generate context binary models. The 1st session shares its
+  // QnnBackendManager with the 2nd, so graphs from both models are included in the 2nd context binary.
+  Ort::Session session1(*ort_env, ToPathString(onnx_model_path1).c_str(), so);

-  InferenceSessionWrapper session_object2{so, GetEnvironment()};
-  ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(qnn_ep_shared));
-  ASSERT_STATUS_OK(session_object2.Load(ToPathString(onnx_model_path2)));
-  ASSERT_STATUS_OK(session_object2.Initialize());
+  Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so);
}
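As a counterpart to the helper above, here is a minimal sketch, not part of this diff, of the inference side: sessions created with ep.share_ep_contexts enabled load the generated *_ctx.onnx models against one shared QNN context. The model paths are hypothetical.

// Sketch: consume previously dumped EP context models with a shared QNN context.
void LoadSharedCtxModels(Ort::Env& env, const ProviderOptions& provider_options) {
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");  // reuse one QnnBackendManager
  so.AppendExecutionProvider("QNN", provider_options);
  // Both sessions resolve their EPContext nodes from the same shared QNN context.
  Ort::Session ctx_session1(env, ORT_TSTR("./model1.onnx_ctx.onnx"), so);
  Ort::Session ctx_session2(env, ORT_TSTR("./model2.onnx_ctx.onnx"), so);
}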

-// from the last context cache Onnx model, find the EPContext node with main_context=1,
-// and get the QNN context binary file name; this context binary contains all graphs from all Onnx models
-static void GetLastContextBinaryFileName(const std::string last_onnx_ctx_file,
-                                         std::string& last_ctx_bin_file,
-                                         const Logger& logger) {
+// from the context cache Onnx model, find the EPContext node with main_context=1,
+// and get the QNN context binary file name
+static void GetContextBinaryFileName(const std::string onnx_ctx_file,
+                                     std::string& last_ctx_bin_file,
+                                     const Logger& logger) {
  std::shared_ptr<Model> ctx_model;
-  ASSERT_STATUS_OK(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, logger));
+  ASSERT_STATUS_OK(Model::Load(ToPathString(onnx_ctx_file), ctx_model, nullptr, logger));
auto& ctx_graph = ctx_model->MainGraph();
for (auto& node : ctx_graph.Nodes()) {
if (node.OpType() == "EPContext") {
@@ -1172,10 +1166,10 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) {

DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]);

-  // Get the last context binary file name
+  // Get the last context binary file name; the latest context binary holds all graphs generated from all models
  std::string last_qnn_ctx_binary_file_name;
-  GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name,
-                               DefaultLoggingManager().DefaultLogger());
+  GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name,
+                           DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty());

// Update generated context cache Onnx model to make the main EPContext node point to
@@ -1272,8 +1266,8 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) {

// Get the last context binary file name
std::string last_qnn_ctx_binary_file_name;
-  GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name,
-                               DefaultLoggingManager().DefaultLogger());
+  GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name,
+                           DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty());

// Update generated context cache Onnx model to make the main EPContext node point to
@@ -1336,6 +1330,61 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) {
}
std::remove(last_qnn_ctx_binary_file_name.c_str());
}

// Use Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled.
// The sessions share the QnnBackendManager, so all graphs from all models compile into the same Qnn context.
TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
provider_options["offload_graph_io_quantization"] = "0";

// Create QDQ models
std::vector<std::string> onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"};
std::vector<std::string> ctx_model_paths;
for (auto model_path : onnx_model_paths) {
CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(std::filesystem::exists(model_path.c_str()));
ctx_model_paths.push_back(model_path + "_ctx.onnx");
}

Ort::SessionOptions so;
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0");
// enable ep.share_ep_contexts so that QNN EP shares the QnnBackendManager across sessions
so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");

so.AppendExecutionProvider("QNN", provider_options);

Ort::Session session1(*ort_env, ToPathString(onnx_model_paths[0]).c_str(), so);
std::string qnn_ctx_binary_file_name1;
GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1,
DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty());

Ort::Session session2(*ort_env, ToPathString(onnx_model_paths[1]).c_str(), so);
std::string qnn_ctx_binary_file_name2;
GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2,
DefaultLoggingManager().DefaultLogger());
EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty());

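// The 2nd context binary includes graphs from both models, so it should be larger than the 1st.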
auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1);
auto file_size_2 = std::filesystem::file_size(qnn_ctx_binary_file_name2);
EXPECT_TRUE(file_size_2 > file_size_1);

// clean up
for (auto model_path : onnx_model_paths) {
ASSERT_EQ(std::remove(model_path.c_str()), 0);
}
for (auto ctx_model_path : ctx_model_paths) {
ASSERT_EQ(std::remove(ctx_model_path.c_str()), 0);
}
ASSERT_EQ(std::remove(qnn_ctx_binary_file_name1.c_str()), 0);
ASSERT_EQ(std::remove(qnn_ctx_binary_file_name2.c_str()), 0);
}
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

} // namespace test
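The body of GetContextBinaryFileName is truncated above. As a hedged sketch, assuming the documented EPContext attributes main_context and ep_cache_context (the latter stores the context binary file name when embed mode is 0), its lookup likely reduces to:

// Sketch: locate the EPContext node flagged main_context=1 and read the
// ep_cache_context attribute naming the QNN context binary file.
for (auto& node : ctx_graph.Nodes()) {
  if (node.OpType() == "EPContext") {
    const auto& attrs = node.GetAttributes();
    auto main_it = attrs.find("main_context");
    if (main_it != attrs.end() && main_it->second.i() == 1) {
      auto cache_it = attrs.find("ep_cache_context");
      if (cache_it != attrs.end()) {
        last_ctx_bin_file = cache_it->second.s();
      }
      return;
    }
  }
}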
