Skip to content

Commit

Permalink
fix bug of Tensor.index with two inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Jun 5, 2024
1 parent e269508 commit ca44aa6
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 47 deletions.
3 changes: 3 additions & 0 deletions tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ dev.1.0.14.20240531
dev.1.0.15.20240603
1. Support parse Tensor.reshape_as
2. Add trans_ReshapeAs2Reshape pass

dev.1.0.16.20240605
1. fix bug of Tensor.index with two inputs
2 changes: 1 addition & 1 deletion tools/pnnx/pass_level7/ScaledDotProductAttenPass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def forward(self, *v_0):
v_8 = torch.matmul(input=v_7, other=v_3)
return v_8

def export_torchscript(attn_mask, dropout_p, is_causal, v_0, save_dir, op_name, attr_data = None):
def export_torchscript(attn_mask, dropout_p, is_causal, v_0, save_dir, op_name, attr_data = None, input_shapes = None):
net = Model(attn_mask, dropout_p, is_causal)
net.eval()
mod = torch.jit.trace(net, v_0)
Expand Down
4 changes: 2 additions & 2 deletions tools/pnnx/pass_level7/Stack2UnsqueezewithCatPass.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def forward(self, *v_0):
tensor_list.append(torch.unsqueeze(vv, self.dim))
v_1 = torch.cat(tensor_list, self.dim)
return v_1

def export_torchscript(dim, v_0, save_dir, op_name, attr_data = []):
def export_torchscript(dim, v_0, save_dir, op_name, attr_data = [], input_shapes = None):
net = Model(dim, attr_data)
net.eval()
mod = torch.jit.trace(net, v_0)
Expand Down
8 changes: 4 additions & 4 deletions tools/pnnx/pass_level7/UnfoldPass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
class Model(nn.Module):
def __init__(self,dilation, kernel_size, padding, stride, input_shapes):
super(Model, self).__init__()
assert len(input_shapes) == 1, 'the num of nn.Unfold input must be equal 1'
assert len(input_shapes) == 1, 'the num of nn.Unfold input must be 1'
input_shape = input_shapes[0]
assert len(input_shape) == 4, 'the dim of nn.Unfold input must be equal 4'
assert len(input_shape) == 4, 'the dim of nn.Unfold input numst be 1'
self.b, c, ih, iw = input_shape
ih += 2 * padding[0]
iw += 2 * padding[1]
Expand Down Expand Up @@ -47,8 +47,8 @@ def forward(self, *v_0):
return v_4


def export_torchscript(dilation, kernel_size, padding, stride, v_0, save_dir, op_name, attr_data = None):
net = Model(dilation, kernel_size, padding, stride)
def export_torchscript(dilation, kernel_size, padding, stride, v_0, save_dir, op_name, attr_data = None, input_shapes = None):
net = Model(dilation, kernel_size, padding, stride, input_shapes)
net.eval()
mod = torch.jit.trace(net, v_0)
pt_path = os.path.join(save_dir, op_name + '.pt').replace('\\','/')
Expand Down
51 changes: 30 additions & 21 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,10 @@ static std::string make_index_expression(const Operator* op)
index_expr = index_expr.substr(5);
indices_index++;
}
size_t pos = 0;
if ((pos = index_expr.find("@")) != std::string::npos) {
index_expr.replace(pos, 1, "v_");
}
for(int i = 0; i < shape.size(); i++)
{
if ( i == indices_index)
Expand Down Expand Up @@ -1819,16 +1823,19 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
else if (op->type == "Tensor.index")
{
// index expr
if (op->inputs.size() == 2)
{
std::string expanded_expr = expand_expression(op->inputs[1]->producer);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str());
}
else
{
std::string index_expr = make_index_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
}
// if (op->inputs.size() == 2)
// {
// std::string expanded_expr = expand_expression(op->inputs[1]->producer);
// fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str());
// }
// else
// {
// std::string index_expr = make_index_expression(op);
// fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
// }
std::string index_expr = make_index_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());

}
else if (op->type == "Tensor.expand")
{
Expand Down Expand Up @@ -3135,7 +3142,7 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,

fprintf(pyfp, "\n");

// get input_shape and input_type add by senli[pnnx_infer]
// get input_shape and input_type add by senli[pnnx_infer]
{
// get shape and type of the input op
std::vector<std::vector<int>> input_shapes;
Expand Down Expand Up @@ -3268,16 +3275,18 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
else if (op->type == "Tensor.index")
{
// index expr
if (op->inputs.size() == 2)
{
std::string expanded_expr = expand_expression(op->inputs[1]->producer);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str());
}
else
{
std::string index_expr = make_index_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
}
// if (op->inputs.size() == 2)
// {
// std::string expanded_expr = expand_expression(op->inputs[1]->producer);
// fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str());
// }
// else
// {
// std::string index_expr = make_index_expression(op);
// fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
// }
std::string index_expr = make_index_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
}
else if (op->type == "Tensor.expand")
{
Expand Down
45 changes: 43 additions & 2 deletions tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,52 @@ void trans_expression2TupleConstruct(Graph& graph)
{
Parameter param = op->params["expr"];
std::string expr = param.s;
// printf("op_name:%s\n",op->name.c_str());
if (expr.front() == '[' && expr.back() == ']')
{
matched = true;
op->type = "prim::TupleConstruct";
op->params.clear();
std::vector<Operand*> outputs = op->outputs;
bool sink_node_is_index = false;
if(outputs[0]->consumers[0]->type == "Tensor.index")
{
sink_node_is_index = true;
}

if (sink_node_is_index)
{
// update expr
std::string out_operand_name = outputs[0]->name;
size_t pos = 0;
if((pos = expr.find("0")) != std::string::npos)
{
expr.replace(pos, 1, out_operand_name);
}
outputs[0]->consumers[0]->params["expr"] = expr;
Operand* input = op->inputs[0];
Operator* pre_node = input->producer;
pre_node->outputs.clear();
for (auto& single_out : outputs)
{
single_out->producer = pre_node;
pre_node->outputs.push_back(single_out);
}
input->producer = 0;
input->consumers.clear();
graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), input));
delete input;

op->inputs.clear();
op->outputs.clear();

graph.ops.erase(graph.ops.begin() + i);
delete op;
}
else
{
op->type = "prim::TupleConstruct";
op->params.clear();
}

break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/py_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// #include <torch/extension.h>
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
#define MYLIBRARY_VERSION "dev.1.0.15.20240603"
#define MYLIBRARY_VERSION "dev.1.0.16.20240605"
using namespace pnnx_graph;
using namespace pnnx_ir;
namespace py = pybind11;
Expand Down
53 changes: 44 additions & 9 deletions tools/pnnx/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,25 @@ def __init__(self,):
def forward(self, v_0):
one_hot = F.one_hot(v_0, num_classes=4)
return one_hot

class reshape_as_Model(nn.Module):
def __init__(self,):
super(reshape_as_Model, self).__init__()

def forward(self, x, y):
output = x.reshape_as(y)
return output

class unfold_Model(nn.Module):
def __init__(self,):
super(unfold_Model, self).__init__()
self.unfold = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(0,1),dilation=(2,2))
def forward(self, x):
output = self.unfold(x)
return output


def export(model_name: str, net: nn.Module, input_shape, export_onnx: bool):
def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx: bool):
if isinstance(input_shape, list):
input_tensor_list = []
input_shape_str_list = []
Expand Down Expand Up @@ -129,10 +145,12 @@ def export(model_name: str, net: nn.Module, input_shape, export_onnx: bool):

else:
input_tensor_list = tuple(input_tensor_list)

mod = torch.jit.trace(net, input_tensor_list)
pt_path = os.path.join(save_dir, model_name + '.pt').replace('\\','/')
mod.save(pt_path)
if isinstance(net, str):
pt_path = net
else:
mod = torch.jit.trace(net, input_tensor_list)
pt_path = os.path.join(save_dir, model_name + '.pt').replace('\\','/')
mod.save(pt_path)
# export pnnx
result = graph.getNvpPnnxModel(pt_path, input_shape_str, 'None', 'None')
assert result == 1, 'failed to export pnnx'
Expand All @@ -152,9 +170,11 @@ def export(model_name: str, net: nn.Module, input_shape, export_onnx: bool):
"index2": IndexModel,
"stack":stackModel,
"oneHot":oneHotModel,
"reshape_as": reshape_as_Model,
"unfold":unfold_Model
}

model_name = 'oneHot'
model_name = 'unfold'
if model_name in net_map:
net = net_map[model_name]()
else:
Expand All @@ -172,11 +192,26 @@ def export(model_name: str, net: nn.Module, input_shape, export_onnx: bool):
# [4.5, 5.7, 1.8],
# ])
# v_0 = torch.rand([1,3,4,4], dtype= float)
v_0 = torch.tensor([0, 2, 1, 3])
input_shape = [v_0]
# v_0 = torch.tensor([0, 2, 1, 3])
# input_shape = [v_0]
# input_shape = [[1,3, 224]]

# v_0 = torch.randn(2, 3) # 第一个输入张量
# v_1 = torch.randn(3, 2) # 第二个输入张量

# input_shape = [v_0,v_1]
export_onnx = True
export(model_name, net, input_shape,export_onnx)
#-------------------------------------------------------
# model_name = 'pvig'
# net = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo/pvig/model.pt'
# input_shape = [[1,3,224,224]]
# export_onnx = False
# ----------------------------

# unfold
input_shape = [[1,3,9,9]]

export(model_name, net, input_shape, export_onnx)
# import pnnx
# pnnx.export

Expand Down
17 changes: 13 additions & 4 deletions tools/pnnx/tools/gen_pass_level7_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def gen_pass_level7_template(ops, output_path, pass_name):
params = cur_op.params
attribute = cur_op.attrs
init_params_name = list(params.keys()) + list(attribute.keys())
init_params_name.append('input_shapes')
op_type = cur_op.type

output_py_path = os.path.join(output_path, pass_name + '.py')
Expand Down Expand Up @@ -60,6 +61,7 @@ def gen_pass_level7_template(ops, output_path, pass_name):
export_name.append('save_dir')
export_name.append('op_name')
export_name.append('attr_data = None')
export_name.append('input_shapes = None')
export_params_name_str = ', '.join(export_name)
f.write("def export_torchscript(" + export_params_name_str + "):\n")
f.write("\tnet = Model(" + init_params_name_str + ")\n")
Expand Down Expand Up @@ -101,14 +103,21 @@ def gen_pass_level7_template(ops, output_path, pass_name):
# pass_name = 'ScaledDotProductAttenPass'

# unfold
pt_path_str = 'D:/project/programs/my_project/tests/test_python/test_op/model_zoo2/unfold/unfold.pt'
input_shape_str = '[1,1,4,4]'
pass_name = 'UnfoldPass'
pt_path_str = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo/unfold/unfold.pt'
input_shape_str = '[1,3,9,9]'
pass_name = 'UnfoldPass_new'

# custom_op_path_str =
# infer_py_path =
# gen pnnx model
output_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/pass_level7/template'
if platform.system() == "Windows":
output_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/pass_level7/template'
elif platform.system() == "Linux":
output_path = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/pass_level7/template'
else:
assert False, "noly support win and linux"



operators, operands, input_ops, output_ops = parser.getNvpPnnxModel(pt_path_str, input_shape_str)

Expand Down
9 changes: 6 additions & 3 deletions tools/pnnx/tools/pass_level7_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,13 @@ def get_src_node_info(op):
input_names, input_shapes, input_datas = [], [], []
attr_input_names, attr_input_datas = [], []
inOperands = op.inputs
for operand in inOperands:
for index, operand in enumerate(inOperands):
if operand.producer.type == 'pnnx.Attribute':
attrs_params = ParseAttrs(operand.producer)
operand_dict = attrs_params["data"]
attr_input_names.append(operand.name)
attr_input_datas.append(operand_dict['data'].reshape(operand_dict['shape']))
# attr_input_datas.append(operand_dict['data'].reshape(operand_dict['shape']))
attr_input_datas[index] = torch.from_numpy(operand_dict['data'].reshape(operand_dict['shape']))
else:
input_names.append(operand.name)
input_shapes.append(operand.shape)
Expand Down Expand Up @@ -206,7 +207,9 @@ def run_pass(op,operator_dict, operand_dict):
all_params_dict['v_0'] = input_datas
all_params_dict['save_dir'] = pass_level7_tmp_output_path
all_params_dict['op_name'] = op_name
all_params_dict['attr_data'] = [torch.from_numpy(attr_input_data) for attr_input_data in attr_input_datas]
# all_params_dict['attr_data'] = [torch.from_numpy(attr_input_data) for attr_input_data in attr_input_datas]
all_params_dict['attr_data'] = attr_input_datas
all_params_dict['input_shapes'] = input_shapes
export_pt(**all_params_dict)

pass_pt_path = os.path.join(pass_level7_tmp_output_path, op_name + '.pt').replace('\\','/')
Expand Down

0 comments on commit ca44aa6

Please sign in to comment.