Skip to content

Commit

Permalink
Support for upper bound
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinwadhwa921 committed Dec 20, 2024
1 parent 8d79f90 commit 576b124
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 1 deletion.
53 changes: 53 additions & 0 deletions onnxruntime/core/providers/openvino/backends/basic_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
return;

// OV Config

ov::AnyMap device_config;
PopulateConfigValue(device_config);

Expand Down Expand Up @@ -133,6 +134,41 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
inferRequestsQueue_ = std::unique_ptr<InferRequestsQueue>(new InferRequestsQueue(exe_network_, num_infer_req));
}

// Parses a shape-override string of the form "name1[dims],name2[dims],..."
// (e.g. "input_1[1,23,4,5],input_2[1,2,3,4]") into a map from input name to
// the list of bracketed shape strings supplied for that input.
// A '[' group with no preceding name (e.g. "a[1,2],[3,4]") reuses the previous
// input name, so one input can accumulate several shape entries; the caller is
// responsible for rejecting that case. Throws via ORT_THROW when an input name
// is empty or when the string cannot be fully parsed.
std::map<std::string, std::vector<std::string>> BasicBackend::parse_input_shapes(const std::string& parameter_string) {
  std::map<std::string, std::vector<std::string>> return_value;
  std::string search_string = parameter_string;
  auto start_pos = search_string.find_first_of('[');
  // Text before the first '[' is the first input's name.
  auto input_name = search_string.substr(0, start_pos);
  while (start_pos != std::string::npos) {
    auto end_pos = search_string.find_first_of(']');
    // Missing ']' — or a ']' that appears before the '[' — is malformed input.
    // Break so the trailing-text check below reports the parse failure instead
    // of letting (end_pos - start_pos - 1) underflow into a garbage substring.
    if (end_pos == std::string::npos || end_pos < start_pos) {
      break;
    }
    if (start_pos != 0) {
      // A non-empty prefix before '[' names a new input; when start_pos is 0
      // the shape group has no name and the previous input name is reused.
      input_name = search_string.substr(0, start_pos);
    }
    auto input_value = search_string.substr(start_pos + 1, end_pos - start_pos - 1);
    if (!input_name.empty()) {
      return_value[input_name].push_back(input_value);
    } else {
      ORT_THROW("Please provide with a valid input name in the shape parameter");
    }
    // Continue after the closing bracket; further groups must be ','-separated.
    search_string = search_string.substr(end_pos + 1);
    if (search_string.empty() || search_string.front() != ',') {
      break;
    }
    search_string = search_string.substr(1);  // consume the separating ','
    start_pos = search_string.find_first_of('[');
  }
  // Any leftover text means the string did not fully parse.
  if (!search_string.empty()) {
    ORT_THROW("CANNOT PARSE INPUT PARAMETER STRING: " + parameter_string);
  }
  return return_value;
}

bool BasicBackend::ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
if (const_outputs_map.size() == subgraph_context_.output_names.size())
subgraph_context_.is_constant = true;
Expand Down Expand Up @@ -190,6 +226,23 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
#endif
}

if(!global_context_.reshape_input.empty())
{
// e.g. input1 => ["1,23", "256", "76..457"]   (one bracketed shape string per entry)
//      input2 => []
// The parsed shape map is stored in the global context for later use.

global_context_.Shape_map = parse_input_shapes(global_context_.reshape_input);

for(const auto&[key,vec] : global_context_.Shape_map) {
if (vec.size()>1) {
ORT_THROW("shape command line parameter doesn't support multiple shapes for one input.");
}
}


}

if (!global_context_.load_config.empty()) {
const std::map<std::string, ov::AnyMap>& target_config = global_context_.load_config;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class BasicBackend : public IBackend {
void EnableStreams();
void SetNumThreads(ov::AnyMap& device_config);
void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
std::map<std::string, std::vector<std::string>> parse_input_shapes(const std::string&);

#ifdef IO_BUFFER_ENABLED
void StartRemoteAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/openvino/contexts.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ struct GlobalContext {
std::string precision_str;
std::string model_precision;
std::string cache_dir;
std:: string reshape_input;
std::map<std::string, std::vector<std::string>> Shape_map;
std::map<std::string, ov::AnyMap> load_config;
std::string model_priority = "DEFAULT";
int num_streams;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ struct OpenVINOExecutionProviderInfo {
std::map<std::string, ov::AnyMap> load_config_{};
std::string cache_dir_{""};
std::string model_priority_{""};
std:: string reshape_input_{""};
int num_streams_{1};
void* context_{NULL};
bool enable_opencl_throttling_{false};
Expand All @@ -98,6 +99,7 @@ struct OpenVINOExecutionProviderInfo {
size_t num_of_threads,
const std::map<std::string, ov::AnyMap>& load_config,
const std::string& cache_dir,
const std::string& reshape_input,
const std::string& model_priority, int num_streams,
void* context, bool enable_opencl_throttling,
bool disable_dynamic_shapes, bool export_ep_ctx_blob,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory {
const std::map<std::string, ov::AnyMap> load_config_;
std::string cache_dir_;
std::string model_priority_;
std::string reshape_input_;
int num_streams_;
void* context_;
bool enable_opencl_throttling_;
Expand Down Expand Up @@ -74,7 +75,7 @@ std::unique_ptr<IExecutionProvider> OpenVINOProviderFactory::CreateProvider() {
}

OpenVINOExecutionProviderInfo info(device_type_, precision_, num_of_threads_, load_config_,
cache_dir_, model_priority_, num_streams_, context_, enable_opencl_throttling_,
cache_dir_, model_priority_, reshape_input_, num_streams_, context_, enable_opencl_throttling_,
disable_dynamic_shapes_, so_export_ep_ctx_blob, enable_qdq_optimizer_,
so_disable_cpu_fallback, so_epctx_embed_mode);
return std::make_unique<OpenVINOExecutionProvider>(info);
Expand Down Expand Up @@ -115,6 +116,7 @@ struct OpenVINO_Provider : Provider {
// dump and load the blobs for the model caching/kernel caching
// (GPU) feature. If blob files are already present,
// it will be directly loaded.
std:: string reshape_input = ""; // Sets the range for models with dynamic input shape.
std::string model_priority = "DEFAULT"; // High-level OpenVINO model priority hint
// Defines what model should be provided with more performant
// bounded resource first
Expand Down Expand Up @@ -203,6 +205,10 @@ struct OpenVINO_Provider : Provider {
cache_dir = provider_options_map.at("cache_dir");
}

if(provider_options_map.find("reshape_input")!= provider_options_map.end()) {
reshape_input = provider_options_map.at("reshape_input");
}

if (provider_options_map.find("load_config") != provider_options_map.end()) {
auto parse_config = [&](const std::string& config_str) -> std::map<std::string, ov::AnyMap> {
// If the config string is empty, return an empty map and skip processing
Expand Down

0 comments on commit 576b124

Please sign in to comment.