diff --git a/tools/pnnx/Releasenotes b/tools/pnnx/Releasenotes index c4e432f1c9e..87a8167f2af 100644 --- a/tools/pnnx/Releasenotes +++ b/tools/pnnx/Releasenotes @@ -44,4 +44,7 @@ dev.1.0.12.20240529 1. Add getInputType function in infer py dev.1.0.13.20240530 -1. Trans string to char in getInputType function \ No newline at end of file +1. Trans string to char in getInputType function + +dev.1.0.14.20240531 +1. Fix bug of make_index_expression for gen tensor.index infer op \ No newline at end of file diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 9d7f9dc866e..e854fd381cb 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1406,21 +1406,48 @@ static std::string make_index_expression(const Operator* op) std::string index_expr = op->params.at("expr").s; // strip out-most [ ] pair - index_expr = index_expr.substr(1, index_expr.size() - 2); + // index_expr = index_expr.substr(1, index_expr.size() - 2); - // None,None, -> ..., - bool leading_none = false; + // // None,None, -> ..., + // bool leading_none = false; + // while (index_expr.substr(0, 5) == "None,") + // { + // leading_none = true; + // index_expr = index_expr.substr(5); + // } + // if (leading_none) + // { + // index_expr = "...," + index_expr; + // } + + // return index_expr; + std::vector shape = op->inputs.at(0)->shape; + std::string out_index_expr = ""; + index_expr = index_expr.substr(1, index_expr.size() - 2); + int indices_index = 0; while (index_expr.substr(0, 5) == "None,") { - leading_none = true; + index_expr = index_expr.substr(5); + indices_index++; } - if (leading_none) + for(int i = 0; i < shape.size(); i++) { - index_expr = "...," + index_expr; - } + if ( i == indices_index) + { + out_index_expr = out_index_expr + index_expr; - return index_expr; + }else + { + out_index_expr = out_index_expr + ":"; + + } + if ( i != shape.size() - 1) + { + out_index_expr = out_index_expr + ","; + } + } + return out_index_expr; } int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) diff --git a/tools/pnnx/src/py_proj.cpp b/tools/pnnx/src/py_proj.cpp index c38159adb7c..db2661bbcc8 100644 --- a/tools/pnnx/src/py_proj.cpp +++ b/tools/pnnx/src/py_proj.cpp @@ -5,7 +5,7 @@ // #include #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) -#define MYLIBRARY_VERSION "dev.1.0.13.20240530" +#define MYLIBRARY_VERSION "dev.1.0.14.20240531" using namespace pnnx_graph; using namespace pnnx_ir; namespace py = pybind11;