From 33c50c14e6eba5ad6410819ac680785040f94457 Mon Sep 17 00:00:00 2001 From: "sen.li" Date: Thu, 11 Jul 2024 14:36:01 +0800 Subject: [PATCH] 1. Fix the bug of trans_expression2TupleConstruct 2. Add skip_pass_level6 and only_save_main mod 3.Process multi prim::TupleConstruct node at output --- tools/pnnx/Releasenotes | 25 ++- tools/pnnx/src/ir.cpp | 154 +++++++++++++----- tools/pnnx/src/ir.h | 2 +- tools/pnnx/src/main.cpp | 79 +++++++-- tools/pnnx/src/parse/pnnx_graph_parse.cpp | 88 +++++++--- tools/pnnx/src/parse/pnnx_graph_parse.h | 11 +- .../src/pass_level3/fuse_index_expression.cpp | 2 +- .../trans_expression2TupleConstruct.cpp | 100 ++++++++++-- tools/pnnx/src/py_proj.cpp | 9 +- 9 files changed, 356 insertions(+), 114 deletions(-) diff --git a/tools/pnnx/Releasenotes b/tools/pnnx/Releasenotes index 65b4e4b504e1..a2f961ff9ccb 100644 --- a/tools/pnnx/Releasenotes +++ b/tools/pnnx/Releasenotes @@ -22,6 +22,7 @@ dev.1.0.5.20240508 1. Synchronize the main ncnn repository 2. Fix missing approximate parameters of nn.GELU + dev.1.0.6.20240511 1. Add new pass trans_Stack2Unsqueeze, When using torch.stack with a single input and effectively achieving the same result as torch.unsqueeze @@ -47,7 +48,7 @@ dev.1.0.13.20240530 1. Trans string to char in getInputType function dev.1.0.14.20240531 -1. Fix bug of make_index_expression for gen tensor.index infer op +1. Fix bug of make_index_expression for gen tensor.index infer op dev.1.0.15.20240603 1. Support parse Tensor.reshape_as @@ -55,6 +56,7 @@ dev.1.0.15.20240603 dev.1.0.16.20240605 1. fix bug of Tensor.index with two inputs + dev.1.0.17.20240606 1. Add trans_TensorTypeAs2TensorTo pass in pass level 7 @@ -62,17 +64,24 @@ dev.1.0.17.20240606 dev.1.0.18.20240613 1. Skip conv2d nodes of type NoneType + dev.1.0.19.20240614 1. Add extracting sub graph function -dev.1.0.20.20240617 +dev.1.0.20.20240620 1. Add loop op parse function +2. Support export sub_model +3. Support load input tensor to export +4. Support torchvision.ops.nms + -dev.1.0.21.20240619 -1. Support export sub_model +dev.1.0.21.20240627 +1. Support parse multi block +2. Support if block -dev.1.0.22.20240620 -1. Support load input tensor to export +dev.1.0.22.20240709 +1. Process multi prim::TupleConstruct node at output -dev.1.0.23.20240627 -1. Support If/Loop block \ No newline at end of file +dev.1.0.23.20240711 +1. Fix the bug of trans_expression2TupleConstruct +2. Add skip_pass_level6 and only_save_main mod \ No newline at end of file diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 53f713817494..10ba30b134ca 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1432,10 +1432,24 @@ static std::string make_index_expression(const Operator* op) indices_index++; } size_t pos = 0; - if ((pos = index_expr.find("@")) != std::string::npos) { + while((pos = index_expr.find("@")) != std::string::npos) { index_expr.replace(pos, 1, "v_"); } - for(int i = 0; i < shape.size(); i++) + int input_size = op->inputs.size(); + int loop_num = 0; + if(input_size == 1) + { + int indice_num = op->params.at("indice_num").i; + loop_num = shape.size() - indice_num + 1; + } + else + { + loop_num = shape.size() - (input_size - 1) + 1; + } + + // fprintf(stderr, "############# indice_num: %s\n", std::to_string(indice_num).c_str()); + // fprintf(stderr, "############# loop_num: %s\n", std::to_string(loop_num).c_str()); + for(int i = 0; i < loop_num; i++) { if ( i == indices_index) { @@ -1446,7 +1460,7 @@ static std::string make_index_expression(const Operator* op) out_index_expr = out_index_expr + ":"; } - if ( i != shape.size() - 1) + if ( i != loop_num - 1) { out_index_expr = out_index_expr + ","; } @@ -1822,20 +1836,30 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) } else if (op->type == "Tensor.index") { - // index expr - // if (op->inputs.size() == 2) - // { - // std::string expanded_expr = expand_expression(op->inputs[1]->producer); - // fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); - // } - // else - // { - // std::string index_expr = make_index_expression(op); - // fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); - // } - std::string index_expr = make_index_expression(op); - fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); - + if(!skip_pass_level6) + { + + fprintf(stderr, "############# gen python with Tensor.index at %s\n", op->name.c_str()); + + std::string index_expr = make_index_expression(op); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); + + } + else + { + fprintf(stderr, "############# gen python with Tensor.index at %s\n", op->name.c_str()); + // index expr + if (op->inputs.size() == 2) + { + std::string expanded_expr = expand_expression(op->inputs[1]->producer); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); + } + else + { + std::string index_expr = make_index_expression(op); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); + } + } } else if (op->type == "Tensor.expand") { @@ -3400,19 +3424,30 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath, } else if (op->type == "Tensor.index") { - // index expr - // if (op->inputs.size() == 2) - // { - // std::string expanded_expr = expand_expression(op->inputs[1]->producer); - // fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); - // } - // else - // { - // std::string index_expr = make_index_expression(op); - // fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); - // } - std::string index_expr = make_index_expression(op); - fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); + if(!skip_pass_level6) + { + + fprintf(stderr, "############# gen python with Tensor.index at %s\n", op->name.c_str()); + + std::string index_expr = make_index_expression(op); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); + + } + else + { + fprintf(stderr, "############# gen python with Tensor.index at %s\n", op->name.c_str()); + // index expr + if (op->inputs.size() == 2) + { + std::string expanded_expr = expand_expression(op->inputs[1]->producer); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); + } + else + { + std::string index_expr = make_index_expression(op); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); + } + } } else if (op->type == "Tensor.expand") { @@ -4029,29 +4064,64 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath, // return if pre node type is TupleConstruct, max_tensor_index not add one add by senli[pnnx_infer] { + // bool TupleConstruct_flag = false; + // int max_tensor_index = 0; + // for (const Operator* op : ops) + // { + // if (op->type == "pnnx.Output") + // { + // std::vector inputs = op->inputs; + // for (const Operand* tensor : inputs) + // { + // Operator* pre_op = tensor->producer; + // if (pre_op->type == "prim::TupleConstruct") + // { + // TupleConstruct_flag = true; + // } + // } + // int num = std::stoi(op->inputs[0]->name); + // max_tensor_index = (max_tensor_index > num) ? max_tensor_index : num; + // } + // } + bool TupleConstruct_flag = false; int max_tensor_index = 0; - for (const Operator* op : ops) + std::queue output_queue; + for (auto op : ops) { if (op->type == "pnnx.Output") { - std::vector inputs = op->inputs; - for (const Operand* tensor : inputs) + output_queue.push(op); + break; + } + } + while(!output_queue.empty()) + { + auto cur_output_op = output_queue.front(); + output_queue.pop(); + std::vector inputs = cur_output_op->inputs; + for (const Operand* tensor : inputs) + { + Operator* pre_op = tensor->producer; + if (pre_op->type == "prim::TupleConstruct") { - Operator* pre_op = tensor->producer; - if (pre_op->type == "prim::TupleConstruct") + TupleConstruct_flag = true; + output_queue.push(pre_op); + } + else + { + for(auto out: pre_op->outputs) { - TupleConstruct_flag = true; + int num = std::stoi(out->name); + max_tensor_index = (max_tensor_index > num) ? max_tensor_index : num; } + } - int num = std::stoi(op->inputs[0]->name); - max_tensor_index = (max_tensor_index > num) ? max_tensor_index : num; } } - if (!TupleConstruct_flag) - { - max_tensor_index++; - } + + max_tensor_index++; + fprintf(pyfp, " intermediate = {}\n"); fprintf(pyfp, " for i in range(%d):\n", max_tensor_index); fprintf(pyfp, " key = 'v_' + str(i)\n"); diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 88080f217159..2d02c03e7715 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -358,7 +358,7 @@ class Graph std::vector ops; std::vector operands; - + int skip_pass_level6 = 0; private: Graph(const Graph& rhs); Graph& operator=(const Graph& rhs); diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index cfbab67f6dd2..a14b66d292e9 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -235,6 +235,8 @@ int main(int argc, char** argv) std::vector start_nodes; std::vector end_nodes; std::string extract_model_name = "model"; + int skip_pass_level6 = 0; + int only_save_main = 0; for (int i = 3; i < argc; i++) { // key=value @@ -293,6 +295,10 @@ int main(int argc, char** argv) parse_string_list(value, end_nodes); if (strcmp(key, "extract_model_name") == 0) extract_model_name = value; + if (strcmp(key, "skip_pass_level6") == 0) + skip_pass_level6 = atoi(value); + if (strcmp(key, "only_save_main") == 0) + only_save_main = atoi(value); } @@ -336,6 +342,10 @@ int main(int argc, char** argv) fprintf(stderr, "\n"); fprintf(stderr, "extract_model_name = %s\n", extract_model_name.c_str()); fprintf(stderr, "\n"); + fprintf(stderr, "skip_pass_level6 = %d\n", skip_pass_level6); + fprintf(stderr, "\n"); + fprintf(stderr, "only_save_main = %d\n", only_save_main); + fprintf(stderr, "\n"); } @@ -361,15 +371,23 @@ int main(int argc, char** argv) // loop all graph tp pass std::queue> main_graph_queue; main_graph_queue.push(pnnx_graph); - while( !main_graph_queue.empty()) + while( !main_graph_queue.empty()) { std::shared_ptr cur_main_graph = main_graph_queue.front(); main_graph_queue.pop(); std::shared_ptr graph = cur_main_graph->get_main_graph(); - for(auto pair: cur_main_graph->sub_graph_map) + if(!only_save_main) + { + for(auto pair: cur_main_graph->sub_graph_map) + { + main_graph_queue.push(pair.second); + } + } + else { - main_graph_queue.push(pair.second); - } + fprintf(stderr, "############# only pass main model\n"); + } + std::string graph_name = cur_main_graph->name; // if(graph_name == "src") // graph_name = "model"; @@ -401,17 +419,34 @@ int main(int argc, char** argv) pnnx::pass_level5(graph, foldable_constants, foldable_constants_zippath); - // add by senli 20240321 - fprintf(stderr, "############# pass_level6 at %s\n", graph_name.c_str()); + if(!skip_pass_level6) + { + // add by senli 20240321 + fprintf(stderr, "############# pass_level6 at %s\n", graph_name.c_str()); + + pnnx::pass_level6(graph, foldable_constants, foldable_constants_zippath); - pnnx::pass_level6(graph, foldable_constants, foldable_constants_zippath); + } + else + { + fprintf(stderr, "############# skip pass_level6 at %s\n", graph_name.c_str()); + } + } } - // sub_graph_pass - fprintf(stderr, "############# pass_sub_model\n"); - pnnx::pass_sub_model(pnnx_graph); + if(!only_save_main) + { + // sub_graph_pass + fprintf(stderr, "############# pass_sub_model\n"); + pnnx::pass_sub_model(pnnx_graph); + } + else + { + fprintf(stderr, "############# not need to pass sub model\n"); + } + // save graph std::queue> main_graph_queue2; @@ -422,16 +457,26 @@ int main(int argc, char** argv) std::shared_ptr cur_main_graph2 = main_graph_queue2.front(); main_graph_queue2.pop(); std::shared_ptr graph2 = cur_main_graph2->get_main_graph(); + graph2->skip_pass_level6 = skip_pass_level6; - for(auto pair2: cur_main_graph2->sub_graph_map) + if(!only_save_main) { - auto it = std::find(cur_main_graph2->effective_sub_model_name.begin(), cur_main_graph2->effective_sub_model_name.end(),pair2.first); - if(it != cur_main_graph2->effective_sub_model_name.end()) + for(auto pair2: cur_main_graph2->sub_graph_map) { - main_graph_queue2.push(pair2.second); - } - - } + auto it = std::find(cur_main_graph2->effective_sub_model_name.begin(), cur_main_graph2->effective_sub_model_name.end(),pair2.first); + if(it != cur_main_graph2->effective_sub_model_name.end()) + { + main_graph_queue2.push(pair2.second); + } + + } + } + else + { + fprintf(stderr, "############# only save main model\n"); + } + + std::string graph_name = cur_main_graph2->name; // if(graph_name == "src") // graph_name = "model"; diff --git a/tools/pnnx/src/parse/pnnx_graph_parse.cpp b/tools/pnnx/src/parse/pnnx_graph_parse.cpp index eb842be72da7..1174f05a8e24 100644 --- a/tools/pnnx/src/parse/pnnx_graph_parse.cpp +++ b/tools/pnnx/src/parse/pnnx_graph_parse.cpp @@ -2,23 +2,44 @@ int main(int argc, char** argv); namespace pnnx_graph { + +static std::string getDirectory(const std::string& path) +{ + std::string dirpath; + std::string filename; + + size_t dirpos = path.find_last_of("/\\"); + if (dirpos != std::string::npos) + { + dirpath = path.substr(0, dirpos + 1); + filename = path.substr(dirpos + 1); + } + else + { + dirpath = ""; + filename = path; + } + return dirpath; +} + bool PnnxGraph::getNvpPnnxModel(const std::string& pt_path, const std::string& input_shape, const std::string& custom_op_path, - const std::string& custom_op_py, const std::string& start_nodes, const std::string& end_nodes, const std::string& extract_model_name) + const std::string& custom_op_py, const std::string& start_nodes, const std::string& end_nodes, + const std::string& extract_model_name, const std::string& skip_pass_level6, const std::string& only_save_main) { int argc; char** argv; if (custom_op_path != "None" && custom_op_py != "None") { - argc = 9; + argc = 11; } else if (custom_op_path != "None" && custom_op_py == "None") { - argc = 8; + argc = 10; } else if (custom_op_path == "None" && custom_op_py == "None") { - argc = 7; + argc = 9; } argv = new char*[argc]; @@ -55,20 +76,31 @@ bool PnnxGraph::getNvpPnnxModel(const std::string& pt_path, const std::string& i std::strcpy(argv[5], custom_op_py_info.c_str()); } - //insert start nodes + //insert start nodes std::string stard_nodes_info = "start_nodes=" + start_nodes; - argv[argc - 3] = new char[stard_nodes_info.size() + 1]; - std::strcpy( argv[argc - 3], stard_nodes_info.c_str()); + argv[argc - 5] = new char[stard_nodes_info.size() + 1]; + std::strcpy( argv[argc - 5], stard_nodes_info.c_str()); //insert end nodes std::string end_nodes_info = "end_nodes=" + end_nodes; - argv[argc - 2] = new char[end_nodes_info.size() + 1]; - std::strcpy( argv[argc - 2], end_nodes_info.c_str()); + argv[argc - 4] = new char[end_nodes_info.size() + 1]; + std::strcpy( argv[argc - 4], end_nodes_info.c_str()); //insert extract_model_name std::string extract_model_name_info = "extract_model_name=" + extract_model_name; - argv[argc - 1] = new char[extract_model_name_info.size() + 1]; - std::strcpy(argv[argc - 1], extract_model_name_info.c_str()); + argv[argc - 3] = new char[extract_model_name_info.size() + 1]; + std::strcpy(argv[argc - 3], extract_model_name_info.c_str()); + + + //insert skip_pass_level6 + std::string skip_pass_level6_info = "skip_pass_level6=" + skip_pass_level6; + argv[argc - 2] = new char[skip_pass_level6_info.size() + 1]; + std::strcpy(argv[argc - 2], skip_pass_level6_info.c_str()); + + //insert only_save_main + std::string only_save_main_info = "only_save_main=" + only_save_main; + argv[argc - 1] = new char[only_save_main_info.size() + 1]; + std::strcpy(argv[argc - 1], only_save_main_info.c_str()); int result = main(argc, argv); @@ -87,22 +119,23 @@ bool PnnxGraph::getNvpPnnxModel(const std::string& pt_path, const std::string& i } bool PnnxGraph::getNvpPnnxModelV1(const std::string& pt_path, const std::string& save_dir, const std::string& input_shape, const std::string& custom_op_path, - const std::string& custom_op_py, const std::string& start_nodes, const std::string& end_nodes, const std::string& extract_model_name) + const std::string& custom_op_py, const std::string& start_nodes, const std::string& end_nodes, const std::string& extract_model_name, + const std::string& skip_pass_level6, const std::string& only_save_main) { int argc; char** argv; if (custom_op_path != "None" && custom_op_py != "None") { - argc = 9; + argc = 11; } else if (custom_op_path != "None" && custom_op_py == "None") { - argc = 8; + argc = 10; } else if (custom_op_path == "None" && custom_op_py == "None") { - argc = 7; + argc = 9; } argv = new char*[argc]; @@ -140,18 +173,31 @@ bool PnnxGraph::getNvpPnnxModelV1(const std::string& pt_path, const std::string& //insert start nodes std::string stard_nodes_info = "start_nodes=" + start_nodes; - argv[argc - 3] = new char[stard_nodes_info.size() + 1]; - std::strcpy( argv[argc - 3], stard_nodes_info.c_str()); + argv[argc - 5] = new char[stard_nodes_info.size() + 1]; + std::strcpy( argv[argc - 5], stard_nodes_info.c_str()); //insert end nodes std::string end_nodes_info = "end_nodes=" + end_nodes; - argv[argc - 2] = new char[end_nodes_info.size() + 1]; - std::strcpy( argv[argc - 2], end_nodes_info.c_str()); + argv[argc - 4] = new char[end_nodes_info.size() + 1]; + std::strcpy( argv[argc - 4], end_nodes_info.c_str()); //insert extract_model_name std::string extract_model_name_info = "extract_model_name=" + extract_model_name; - argv[argc - 1] = new char[extract_model_name_info.size() + 1]; - std::strcpy(argv[argc - 1], extract_model_name_info.c_str()); + argv[argc - 3] = new char[extract_model_name_info.size() + 1]; + std::strcpy(argv[argc - 3], extract_model_name_info.c_str()); + + + //insert skip_pass_level6 + std::string skip_pass_level6_info = "skip_pass_level6=" + skip_pass_level6; + argv[argc - 2] = new char[skip_pass_level6_info.size() + 1]; + std::strcpy(argv[argc - 2], skip_pass_level6_info.c_str()); + + //insert only_save_main + std::string only_save_main_info = "only_save_main=" + only_save_main; + argv[argc - 1] = new char[only_save_main_info.size() + 1]; + std::strcpy(argv[argc - 1], only_save_main_info.c_str()); + + int result = main(argc, argv); diff --git a/tools/pnnx/src/parse/pnnx_graph_parse.h b/tools/pnnx/src/parse/pnnx_graph_parse.h index 26f608868139..9a1e40c11bde 100644 --- a/tools/pnnx/src/parse/pnnx_graph_parse.h +++ b/tools/pnnx/src/parse/pnnx_graph_parse.h @@ -11,7 +11,7 @@ class PnnxGraph { public: -/** + /** * @brief Get the Nvp Pnnx Model object * * @param pt_path torchscript path @@ -30,7 +30,9 @@ class PnnxGraph const std::string& custom_op_py,\ const std::string& start_nodes = "",\ const std::string& end_nodes = "",\ - const std::string& extract_model_name = "model"); + const std::string& extract_model_name = "model", + const std::string& skip_pass_level6 = "0", + const std::string& only_save_main = "0"); /** * @brief Get the Nvp Pnnx Model object @@ -53,8 +55,9 @@ class PnnxGraph const std::string& custom_op_py,\ const std::string& start_nodes = "",\ const std::string& end_nodes = "",\ - const std::string& extract_model_name = "model"); - + const std::string& extract_model_name = "model", + const std::string& skip_pass_level6 = "0", + const std::string& only_save_main = "0"); diff --git a/tools/pnnx/src/pass_level3/fuse_index_expression.cpp b/tools/pnnx/src/pass_level3/fuse_index_expression.cpp index 718d86970042..18691382771e 100644 --- a/tools/pnnx/src/pass_level3/fuse_index_expression.cpp +++ b/tools/pnnx/src/pass_level3/fuse_index_expression.cpp @@ -170,7 +170,7 @@ void fuse_index_expression(std::shared_ptr graph) matched = true; std::string expr = fuse_attribute_expression(op2); - + op->params["indice_num"] = int(op2->inputs.size()); op->params["expr"] = expr; op->inputs[1]->producer = 0; diff --git a/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp b/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp index 257594d9bcf6..78f073d6664e 100644 --- a/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp +++ b/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp @@ -36,9 +36,10 @@ void trans_expression2TupleConstruct(std::shared_ptr graph) { Parameter param = op->params["expr"]; std::string expr = param.s; - // printf("op_name:%s\n",op->name.c_str()); + if (expr.front() == '[' && expr.back() == ']') { + printf("op_name:%s\n",op->name.c_str()); matched = true; std::vector outputs = op->outputs; bool sink_node_is_index = false; @@ -46,41 +47,108 @@ void trans_expression2TupleConstruct(std::shared_ptr graph) { sink_node_is_index = true; } - if (sink_node_is_index) { // update expr - std::string out_operand_name = outputs[0]->name; + int input_num = op->inputs.size(); + std::vector cur_op_inputs = op->inputs; size_t pos = 0; - if((pos = expr.find("0")) != std::string::npos) + for(int i = 0; i < input_num; i++) + { + std::string index = std::to_string(i); + std::string operand_name = cur_op_inputs[i]->name; + pos = expr.find('@',pos); + if(pos != std::string::npos) + { + expr.replace(pos+1, 1, operand_name); + } + pos += 1; + } + for(auto out: outputs) + { + for(auto consumer: out->consumers) + { + consumer->params["expr"] = expr; + consumer->inputs.insert(consumer->inputs.end(), cur_op_inputs.begin(), cur_op_inputs.end()); + for(auto input: cur_op_inputs) + { + input->consumers.push_back(consumer); + } + + } + } + // delete cur op and out_operand + for(auto input: cur_op_inputs) { - expr.replace(pos, 1, out_operand_name); + input->consumers.erase(std::find(input->consumers.begin(), input->consumers.end(), op)); } - outputs[0]->consumers[0]->params["expr"] = expr; - Operand* input = op->inputs[0]; - Operator* pre_node = input->producer; - pre_node->outputs.clear(); - for (auto& single_out : outputs) + for(auto out: outputs) { - single_out->producer = pre_node; - pre_node->outputs.push_back(single_out); + for(auto consumer: out->consumers) + { + consumer->inputs.erase(std::find(consumer->inputs.begin(), consumer->inputs.end(), out)); + } } - input->producer = 0; - input->consumers.clear(); - graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), input)); - delete input; + Operand* output = op->outputs[0]; + output->producer = 0; + output->consumers.clear(); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), output)); + delete output; op->inputs.clear(); op->outputs.clear(); graph->ops.erase(graph->ops.begin() + i); delete op; + + + + + + + + + } else { op->type = "prim::TupleConstruct"; op->params.clear(); } + // if (sink_node_is_index) + // { + // // update expr + // std::string out_operand_name = outputs[0]->name; + // size_t pos = 0; + // if((pos = expr.find("0")) != std::string::npos) + // { + // expr.replace(pos, 1, out_operand_name); + // } + // outputs[0]->consumers[0]->params["expr"] = expr; + // Operand* input = op->inputs[0]; + // Operator* pre_node = input->producer; + // pre_node->outputs.clear(); + // for (auto& single_out : outputs) + // { + // single_out->producer = pre_node; + // pre_node->outputs.push_back(single_out); + // } + // input->producer = 0; + // input->consumers.clear(); + // graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), input)); + // delete input; + + // op->inputs.clear(); + // op->outputs.clear(); + + // graph->ops.erase(graph->ops.begin() + i); + // delete op; + // } + // else + // { + // op->type = "prim::TupleConstruct"; + // op->params.clear(); + // } break; } diff --git a/tools/pnnx/src/py_proj.cpp b/tools/pnnx/src/py_proj.cpp index d27132cfd683..8fc8b976b56d 100644 --- a/tools/pnnx/src/py_proj.cpp +++ b/tools/pnnx/src/py_proj.cpp @@ -5,7 +5,7 @@ // #include #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) -#define MYLIBRARY_VERSION "dev.1.0.23.20240627" +#define MYLIBRARY_VERSION "dev.1.0.23.20240711" using namespace pnnx_graph; using namespace pnnx_ir; namespace py = pybind11; @@ -75,11 +75,12 @@ PYBIND11_MODULE(ptx, m) .def("getNvpPnnxModel", &PnnxGraph::getNvpPnnxModel, py::arg("pt_path"), \ py::arg("input_shape"), py::arg("custom_op_path"), \ py::arg("custom_op_py"), py::arg("start_nodes") = "", py::arg("end_nodes") = "",\ - py::arg("extract_model_name") = "model") + py::arg("extract_model_name") = "model",py::arg("skip_pass_level6") = "0",py::arg("only_save_main") = "0") + .def("getNvpPnnxModelV1", &PnnxGraph::getNvpPnnxModelV1, py::arg("pt_path"), \ py::arg("save_dir"), py::arg("input_shape"), py::arg("custom_op_path"), \ py::arg("custom_op_py"), py::arg("start_nodes") = "", py::arg("end_nodes") = "",\ - py::arg("extract_model_name") = "model") + py::arg("extract_model_name") = "model",py::arg("skip_pass_level6") = "0", py::arg("only_save_main") = "0") .def("loadModel", &PnnxGraph::loadModel) .def("saveModel", &PnnxGraph::saveModel) // .def("getOperators", (std::vector(PnnxGraph::*)()) & PnnxGraph::getOperators) @@ -87,7 +88,7 @@ PYBIND11_MODULE(ptx, m) .def("getOperands", &PnnxGraph::getOperands, py::return_value_policy::reference_internal) .def("getInputOps", &PnnxGraph::getInputOps, py::return_value_policy::reference_internal) .def("getOutputOps", &PnnxGraph::getOutputOps, py::return_value_policy::reference_internal); - + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else