Skip to content

Commit

Permalink
update save model name
Browse files Browse the repository at this point in the history
  • Loading branch information
sen.li committed Jun 28, 2024
1 parent 909f165 commit 76829c6
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 12 deletions.
24 changes: 22 additions & 2 deletions tools/pnnx/src/load_torchscript.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ int64_t cuda_version();
// #include "pass_level1_class.h"
namespace pnnx {

static std::string get_modelname(const std::string& path)
{
std::string dirpath;
std::string filename;

size_t dirpos = path.find_last_of("/\\");
if (dirpos != std::string::npos)
{
dirpath = path.substr(0, dirpos + 1);
filename = path.substr(dirpos + 1);
}
else
{
filename = path;
}

std::string base = filename.substr(0, filename.find_last_of('.'));
return base;
}

static bool fileExists(const std::string& path) {
FILE* file = fopen(path.c_str(), "r");
if (file) {
Expand Down Expand Up @@ -625,9 +645,9 @@ int load_torchscript(const std::string& ptpath,
#endif

fprintf(stderr, "############# pass_level1\n");

std::string model_name = get_modelname(ptpath);
pnnx::PassLevel1 pass_level1_class;
pass_level1_class.Process(mod, g, module_operators, pnnx_graph);
pass_level1_class.Process(mod, g, module_operators, pnnx_graph, model_name);
return 0;
}

Expand Down
10 changes: 5 additions & 5 deletions tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ int main(int argc, char** argv)
std::string ptpath = std::string(argv[1]);
std::string save_dir = std::string(argv[2]);
std::string ptbase = get_basename(ptpath);

// std::string pnnxparampath = ptbase + ".pnnx.param";
// std::string pnnxbinpath = ptbase + ".pnnx.bin";
// std::string pnnxpypath = ptbase + "_pnnx.py";
Expand Down Expand Up @@ -371,8 +371,8 @@ int main(int argc, char** argv)
main_graph_queue.push(pair.second);
}
std::string graph_name = cur_main_graph->name;
if(graph_name == "src")
graph_name = "model";
// if(graph_name == "src")
// graph_name = "model";

fprintf(stderr, "############# pass_level2 at %s\n", graph_name.c_str());
pnnx::pass_level2(graph);
Expand Down Expand Up @@ -433,8 +433,8 @@ int main(int argc, char** argv)

}
std::string graph_name = cur_main_graph2->name;
if(graph_name == "src")
graph_name = "model";
// if(graph_name == "src")
// graph_name = "model";

std::string pnnxparampath = save_dir + "/" + graph_name + ".pnnx.param";
std::string pnnxbinpath = save_dir + "/" + graph_name + ".pnnx.bin";
Expand Down
8 changes: 4 additions & 4 deletions tools/pnnx/src/pass_level1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ static void fuse_moduleop_unpack(std::shared_ptr<Graph>& graph, const std::vecto
void PassLevel1::Process(const torch::jit::Module& mod,
const std::shared_ptr<torch::jit::Graph>& g,
const std::vector<std::string>& module_operators,
std::shared_ptr<pnnx::MainGraph>& pnnx_graph)
std::shared_ptr<pnnx::MainGraph>& pnnx_graph,
std::string& model_name)
{
// create main graph
std::string main_graph_name = "src";
pnnx_graph->create_main_graph(main_graph_name);
pnnx_graph->create_main_graph(model_name);
std::shared_ptr<pnnx::Graph> pg = pnnx_graph->get_main_graph();
this->_module_operators = module_operators;

Expand Down Expand Up @@ -485,10 +485,10 @@ void PassLevel1::Process_Loop(const torch::jit::Module& mod,
pnnx_graph->effective_sub_model_name.push_back(loop_block_name);
block_num++;
block_names.push_back(loop_block_name);

}



}


Expand Down
3 changes: 2 additions & 1 deletion tools/pnnx/src/pass_level1.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class PassLevel1
void Process(const torch::jit::Module& mod,
const std::shared_ptr<torch::jit::Graph>& g,
const std::vector<std::string>& module_operators,
std::shared_ptr<pnnx::MainGraph>& pnnx_graph);
std::shared_ptr<pnnx::MainGraph>& pnnx_graph,
std::string& model_name);
private:


Expand Down

0 comments on commit 76829c6

Please sign in to comment.