diff --git a/tools/pnnx/src/load_torchscript.cpp b/tools/pnnx/src/load_torchscript.cpp index 467d7de778c..53c1795f40e 100644 --- a/tools/pnnx/src/load_torchscript.cpp +++ b/tools/pnnx/src/load_torchscript.cpp @@ -34,6 +34,26 @@ int64_t cuda_version(); // #include "pass_level1_class.h" namespace pnnx { +static std::string get_modelname(const std::string& path) +{ + std::string dirpath; + std::string filename; + + size_t dirpos = path.find_last_of("/\\"); + if (dirpos != std::string::npos) + { + dirpath = path.substr(0, dirpos + 1); + filename = path.substr(dirpos + 1); + } + else + { + filename = path; + } + + std::string base = filename.substr(0, filename.find_last_of('.')); + return base; +} + static bool fileExists(const std::string& path) { FILE* file = fopen(path.c_str(), "r"); if (file) { @@ -625,9 +645,9 @@ int load_torchscript(const std::string& ptpath, #endif fprintf(stderr, "############# pass_level1\n"); - + std::string model_name = get_modelname(ptpath); pnnx::PassLevel1 pass_level1_class; - pass_level1_class.Process(mod, g, module_operators, pnnx_graph); + pass_level1_class.Process(mod, g, module_operators, pnnx_graph, model_name); return 0; } diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 74556c89c51..cfbab67f6dd 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -211,7 +211,7 @@ int main(int argc, char** argv) std::string ptpath = std::string(argv[1]); std::string save_dir = std::string(argv[2]); std::string ptbase = get_basename(ptpath); - + // std::string pnnxparampath = ptbase + ".pnnx.param"; // std::string pnnxbinpath = ptbase + ".pnnx.bin"; // std::string pnnxpypath = ptbase + "_pnnx.py"; @@ -371,8 +371,8 @@ int main(int argc, char** argv) main_graph_queue.push(pair.second); } std::string graph_name = cur_main_graph->name; - if(graph_name == "src") - graph_name = "model"; + // if(graph_name == "src") + // graph_name = "model"; fprintf(stderr, "############# pass_level2 at %s\n", graph_name.c_str()); pnnx::pass_level2(graph); @@ -433,8 +433,8 @@ int main(int argc, char** argv) } std::string graph_name = cur_main_graph2->name; - if(graph_name == "src") - graph_name = "model"; + // if(graph_name == "src") + // graph_name = "model"; std::string pnnxparampath = save_dir + "/" + graph_name + ".pnnx.param"; std::string pnnxbinpath = save_dir + "/" + graph_name + ".pnnx.bin"; diff --git a/tools/pnnx/src/pass_level1.cpp b/tools/pnnx/src/pass_level1.cpp index 29293eda967..51c6148c5b0 100644 --- a/tools/pnnx/src/pass_level1.cpp +++ b/tools/pnnx/src/pass_level1.cpp @@ -114,11 +114,11 @@ static void fuse_moduleop_unpack(std::shared_ptr& graph, const std::vecto void PassLevel1::Process(const torch::jit::Module& mod, const std::shared_ptr& g, const std::vector& module_operators, - std::shared_ptr& pnnx_graph) + std::shared_ptr& pnnx_graph, + std::string& model_name) { // create main graph - std::string main_graph_name = "src"; - pnnx_graph->create_main_graph(main_graph_name); + pnnx_graph->create_main_graph(model_name); std::shared_ptr pg = pnnx_graph->get_main_graph(); this->_module_operators = module_operators; @@ -485,10 +485,10 @@ void PassLevel1::Process_Loop(const torch::jit::Module& mod, pnnx_graph->effective_sub_model_name.push_back(loop_block_name); block_num++; block_names.push_back(loop_block_name); - } + } diff --git a/tools/pnnx/src/pass_level1.h b/tools/pnnx/src/pass_level1.h index 6c20a3a011d..0b03e12a92e 100644 --- a/tools/pnnx/src/pass_level1.h +++ b/tools/pnnx/src/pass_level1.h @@ -55,7 +55,8 @@ class PassLevel1 void Process(const torch::jit::Module& mod, const std::shared_ptr& g, const std::vector& module_operators, - std::shared_ptr& pnnx_graph); + std::shared_ptr& pnnx_graph, + std::string& model_name); private: