Skip to content

Commit

Permalink
1. Add getInputType function in infer py
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed May 29, 2024
1 parent 7cafb2e commit e59bbb2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 12 deletions.
5 changes: 4 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,7 @@ dev.1.0.10.20240529
1. Support parse multi dim list to string

dev.1.0.11.20240529
1. Support parse F.one_hot
1. Support parse F.one_hot

dev.1.0.12.20240529
1. Add getInputType function in infer py
42 changes: 32 additions & 10 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3108,24 +3108,28 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,

fprintf(pyfp, "\n");

// get input_shape add by senli[pnnx_infer]
// get input_shape and input_type add by senli[pnnx_infer]
{
// example:
// def getInput(self,):
// return [[1, 3, 32, 32],[1,3,64,64]]

fprintf(pyfp, " def getInput(self,):\n");
fprintf(pyfp, " return [");
//获得op的所有输入的shape
std::vector<std::vector<int> > input_shapes;
// get shape and type of the input op
std::vector<std::vector<int>> input_shapes;
std::vector<std::string> input_types;
for (const Operator* op : ops)
{
if (op->type != "pnnx.Input")
continue;
const Operand* r = op->outputs[0];
input_shapes.push_back(r->shape);
input_types.push_back(type_to_string(r->type));
}
//依次写入shape

//insert shape
// example:
// def getInput(self,):
// return [[1, 3, 32, 32],[1,3,64,64]]

fprintf(pyfp, " def getInput(self,):\n");
fprintf(pyfp, " return [");

for (size_t i = 0; i < input_shapes.size(); i++)
{
std::vector<int> one_input_shape = input_shapes[i];
Expand All @@ -3141,6 +3145,24 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
fprintf(pyfp, ", ");
}
fprintf(pyfp, "]\n");

fprintf(pyfp, "\n");

//insert type
// example:
// def getInputType(self,):
// return ['fp32','i64']
fprintf(pyfp, " def getInputType(self,):\n");
fprintf(pyfp, " return [");
for (size_t i = 0; i < input_types.size(); i++)
{
std::string input_type = input_types[i];
fprintf(pyfp, "'%s'", input_type);
if (i + 1 != input_types.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, "]\n");

}
fprintf(pyfp, "\n");
// utility function
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/py_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// #include <torch/extension.h>
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
#define MYLIBRARY_VERSION "dev.1.0.11.20240529"
#define MYLIBRARY_VERSION "dev.1.0.12.20240529"
using namespace pnnx_graph;
using namespace pnnx_ir;
namespace py = pybind11;
Expand Down

0 comments on commit e59bbb2

Please sign in to comment.