Skip to content

Commit

Permalink
Fix bug of make_index_expression for gen tensor.index infer op
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed May 31, 2024
1 parent 06cb344 commit c63c00c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
5 changes: 4 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,7 @@ dev.1.0.12.20240529
1. Add getInputType function in infer py

dev.1.0.13.20240530
1. Trans string to char in getInputType function
1. Trans string to char in getInputType function

dev.1.0.14.20240531
1. Fix bug of make_index_expression for gen tensor.index infer op
43 changes: 35 additions & 8 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1406,21 +1406,48 @@ static std::string make_index_expression(const Operator* op)
std::string index_expr = op->params.at("expr").s;

// strip out-most [ ] pair
index_expr = index_expr.substr(1, index_expr.size() - 2);
// index_expr = index_expr.substr(1, index_expr.size() - 2);

// None,None, -> ...,
bool leading_none = false;
// // None,None, -> ...,
// bool leading_none = false;
// while (index_expr.substr(0, 5) == "None,")
// {
// leading_none = true;
// index_expr = index_expr.substr(5);
// }
// if (leading_none)
// {
// index_expr = "...," + index_expr;
// }

// return index_expr;
std::vector<int> shape = op->inputs.at(0)->shape;
std::string out_index_expr = "";
index_expr = index_expr.substr(1, index_expr.size() - 2);
int indices_index = 0;
while (index_expr.substr(0, 5) == "None,")
{
leading_none = true;

index_expr = index_expr.substr(5);
indices_index++;
}
if (leading_none)
for(int i = 0; i < shape.size(); i++)
{
index_expr = "...," + index_expr;
}
if ( i == indices_index)
{
out_index_expr = out_index_expr + index_expr;

return index_expr;
}else
{
out_index_expr = out_index_expr + ":";

}
if ( i != shape.size() - 1)
{
out_index_expr = out_index_expr + ",";
}
}
return out_index_expr;
}

int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/py_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// #include <torch/extension.h>
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
#define MYLIBRARY_VERSION "dev.1.0.13.20240530"
#define MYLIBRARY_VERSION "dev.1.0.14.20240531"
using namespace pnnx_graph;
using namespace pnnx_ir;
namespace py = pybind11;
Expand Down

0 comments on commit c63c00c

Please sign in to comment.