Skip to content

Commit

Permalink
1. Support multi dim in fuse_index_expression pass for Tensor.index 2…
Browse files Browse the repository at this point in the history
…. Support parse multi dim list to string
  • Loading branch information
root committed May 29, 2024
1 parent 16daa1e commit 8073429
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 29 deletions.
14 changes: 13 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,16 @@ dev.1.0.5.20240508
2. Fix missing approximate parameters of nn.GELU

dev.1.0.6.20240511
1. Add new pass trans_Stack2Unsqueeze, When using torch.stack with a single input and effectively achieving the same result as torch.unsqueeze
1. Add new pass trans_Stack2Unsqueeze, When using torch.stack with a single input and effectively achieving the same result as torch.unsqueeze

dev.1.0.7.20240521
1. add saveModel function

dev.1.0.8.20240526
1. Solve the memory issue in ptx

dev.1.0.9.20240528
1. Support multi dim in fuse_index_expression pass for Tensor.index

dev.1.0.10.20240529
1. Support parse multi dim list to string
32 changes: 31 additions & 1 deletion tools/pnnx/src/parse/pnnx_ir_parse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@
using namespace pnnx;
namespace pnnx_ir {

// Count the non-overlapping occurrences of `substr` in `str`.
//
// Fix: an empty `substr` previously hung forever — std::string::find("")
// matches at every position and `pos += substr.length()` never advanced.
// An empty needle now yields 0 occurrences.
static size_t countSubstring(const std::string& str, const std::string& substr) {
    if (substr.empty())
        return 0;

    size_t count = 0;
    size_t pos = 0;
    while ((pos = str.find(substr, pos)) != std::string::npos) {
        ++count;
        pos += substr.length(); // skip past this match so matches never overlap
    }
    return count;
}

static bool type_is_integer(int type)
{
if (type == 1) return false;
Expand Down Expand Up @@ -465,6 +477,25 @@ Attribute operator+(const Attribute& a, const Attribute& b)

Parameter Parameter::parse_from_string(const std::string& value)
{
if (value.find('%') != std::string::npos)
{
Parameter p;
p.type = 4;
p.s = value;
return p;
}
size_t count1 = countSubstring(value, "[");
size_t count2 = countSubstring(value, "]");
size_t count3 = countSubstring(value, "(");
size_t count4 = countSubstring(value, ")");
if(count1 > 1 || count2 > 1 || count3 > 1 || count4 >1)
{
Parameter p;
p.type = 4;
p.s = value;
return p;
}

Parameter p;
p.type = 0;

Expand Down Expand Up @@ -535,7 +566,6 @@ Parameter Parameter::parse_from_string(const std::string& value)
p.i = std::stoi(value);
return p;
}

// Default constructor: creates an empty Graph with no operators or operands.
Graph::Graph()
{
}
Expand Down
65 changes: 57 additions & 8 deletions tools/pnnx/src/pass_level3/fuse_index_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,40 @@ static void replaceAll(std::string& str, const std::string& from, const std::str
}
}

// Recursively serialize a flat row-major int64 buffer into a nested bracketed
// list string, e.g. shape {2,2} with data {0,1,1,2} -> "[[0,1],[1,2]]".
//
// depth      : current dimension being emitted (start the recursion at 0)
// attr_shape : tensor shape, one extent per dimension (now taken by const ref)
// pdata      : flat data buffer, assumed to hold prod(attr_shape) elements
// attr_expr  : output string; text is appended in place
// pre_depth  : unused; kept so the signature stays call-compatible
// cur_index  : running read offset into pdata, advanced as leaves are emitted
//
// Fixes vs. the original:
//  - guard against an empty shape (size() - 1 wrapped around on size_t and the
//    else branch then indexed attr_shape[0] out of range — UB)
//  - no signed/unsigned comparison: the last-dimension index is computed once
//    as a signed int instead of comparing `depth` against `size() - 1`
static void multi_expr(int depth, const std::vector<int>& attr_shape, const int64_t* pdata, std::string& attr_expr, int pre_depth, int& cur_index)
{
    (void)pre_depth; // intentionally unused

    if (attr_shape.empty())
        return; // nothing to serialize; avoids size()-1 underflow below

    const int last = (int)attr_shape.size() - 1;
    if (depth == last)
    {
        // Innermost dimension: emit the actual numbers as "[a,b,c]".
        const int n = attr_shape[last];
        attr_expr += "[";
        for (int j = 0; j < n; j++)
        {
            attr_expr += std::to_string(pdata[cur_index + j]);
            if (j != n - 1)
                attr_expr += ",";
        }
        attr_expr += "]";
        cur_index += n; // consume this row of the flat buffer
    }
    else
    {
        // Outer dimension: recurse once per sub-tensor, comma-separated.
        const int cur_dim = attr_shape[depth];
        attr_expr += "[";
        for (int i = 0; i < cur_dim; i++)
        {
            multi_expr(depth + 1, attr_shape, pdata, attr_expr, i, cur_index);
            if (i != cur_dim - 1)
                attr_expr += ",";
        }
        attr_expr += "]";
    }
}
static std::string fuse_attribute_expression(Operator* op_expr)
{
std::string expr = op_expr->params["expr"].s;
Expand Down Expand Up @@ -57,20 +91,35 @@ static std::string fuse_attribute_expression(Operator* op_expr)
}
attr_expr += "]";
}
// else if (attr.type == 5)
// {
// // i64
// const int64_t* pdata = (const int64_t*)attr.data.data();
// attr_expr += "[";
// for (int j = 0; j < count; j++)
// {
// int64_t n = pdata[j];
// attr_expr += std::to_string(n);

// if (j != count - 1)
// attr_expr += ",";
// }
// attr_expr += "]";
// }
else if (attr.type == 5)
{
// i64
const int64_t* pdata = (const int64_t*)attr.data.data();
attr_expr += "[";
for (int j = 0; j < count; j++)
int depth = 0;
int attr_shape_size = attr.shape.size();
std::vector<int> attr_shape;
for(int i = 0; i < attr_shape_size; i++)
{
int64_t n = pdata[j];
attr_expr += std::to_string(n);

if (j != count - 1)
attr_expr += ",";
attr_shape.push_back(attr.shape[i]);
}
attr_expr += "]";
int pre_depth = 0;
int cur_index = 0;
multi_expr(depth, attr_shape, pdata, attr_expr, pre_depth, cur_index);
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/py_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// #include <torch/extension.h>
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
#define MYLIBRARY_VERSION "dev.1.0.6.20240511"
#define MYLIBRARY_VERSION "dev.1.0.10.20240529"
using namespace pnnx_graph;
using namespace pnnx_ir;
namespace py = pybind11;
Expand Down
33 changes: 20 additions & 13 deletions tools/pnnx/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
if platform.system() == "Windows":
sys.path.append('D:/project/programs/ncnn_project/nvppnnx/python/build/lib.win-amd64-cpython-38/pnnx')
elif platform.system() == "Linux":
sys.path.append('/workspace/trans_onnx/project/new_project/nvppnnx/python/build/temp.linux-x86_64-cpython-311/src')
# sys.path.append('/workspace/trans_onnx/project/new_project/nvppnnx/python/build/temp.linux-x86_64-cpython-311/src')
sys.path.append('/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/python/build/temp.linux-x86_64-cpython-311/src')
else:
assert False, "noly support win and linux"
import ptx
Expand All @@ -38,19 +39,19 @@
# def models
#-------------------------------------------

class takeModel(nn.Module):
class IndexModel(nn.Module):
def __init__(self,):
super(takeModel, self).__init__()

def forward(self, v_0):
# indices = torch.tensor([0,2], dtype=torch.long)
indices= torch.tensor([
super(IndexModel, self).__init__()
self.indices= torch.tensor([
[0, 1],
[1, 2],
], dtype=torch.long)
def forward(self, v_0):
# indices = torch.tensor([0,2], dtype=torch.long)

# gathered = torch.gather(v_0, dim=3, index=indices)
# gathered = v_0[:,indices]
gathered = v_0[indices,:]
gathered = v_0[:,self.indices,self.indices,:]
# gathered = v_0[self.indices,:]
return gathered


Expand Down Expand Up @@ -114,22 +115,28 @@ def export(model_name: str, net: nn.Module, input_shape, export_onnx: bool):
if __name__ == "__main__":

net_map = {
"take": takeModel,
"index2": IndexModel,
"stack":stackModel
}

model_name = 'stack'
model_name = 'index2'
if model_name in net_map:
net = net_map[model_name]()
else:
assert False, 'not found model_name: {} in net_map'.format(model_name)

input_shape = [[1,3, 224],[1,3,224]]
# input_shape = [[1,3, 224],[1,3,224]]
# v_0 = torch.tensor( [
# [1.0, 1.2],
# [2.3, 3.4],
# [4.5, 5.7],
# ])
# input_shape = [v_0]
v_0 = torch.tensor( [
[1.0, 1.2, 1.3],
[2.3, 3.4, 1.4],
[4.5, 5.7, 1.8],
])
v_0 = torch.rand([1,3,4,4], dtype= float)
input_shape = [v_0]
export_onnx = True
export(model_name, net, input_shape,export_onnx)
14 changes: 9 additions & 5 deletions tools/pnnx/tools/pass_level7_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,15 @@ def run_pass(op,operator_dict, operand_dict):
# input_shape_str = '[1,197,9,64],[1,197,9,64],[1,197,9,64]'

# multi stack
example_name = 'multi_stack'
pt_path_str = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/model_zoo/stack/stack.pt'
pt_path_str = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo/stack/stack.pt'
input_shape_str = '[1,3,224],[1,3,224]'

# example_name = 'multi_stack'
# pt_path_str = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/model_zoo/stack/stack.pt'
# pt_path_str = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo/stack/stack.pt'
# input_shape_str = '[1,3,224],[1,3,224]'

# index
example_name = 'index'
pt_path_str = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo/index/index.pt'
input_shape_str = '[3,2]'
# custom_op_path_str =
# infer_py_path =
# pass_level7_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/pass_level7'
Expand Down

0 comments on commit 8073429

Please sign in to comment.