From 676db7051c60922d2fe020c2e2853cb0d0d30f09 Mon Sep 17 00:00:00 2001 From: "sen.li" Date: Thu, 4 Jul 2024 11:50:10 +0800 Subject: [PATCH] update pass_level7 --- tools/pnnx/pass_level7/PixelShufflePass.py | 49 ++++++++++++++++++++ tools/pnnx/tools/export.py | 45 +++++++++++++++--- tools/pnnx/tools/gen_pass_level7_template.py | 17 +++++-- tools/pnnx/tools/pass_level7_dict.py | 2 + 4 files changed, 104 insertions(+), 9 deletions(-) create mode 100644 tools/pnnx/pass_level7/PixelShufflePass.py diff --git a/tools/pnnx/pass_level7/PixelShufflePass.py b/tools/pnnx/pass_level7/PixelShufflePass.py new file mode 100644 index 000000000000..3f403712cbc5 --- /dev/null +++ b/tools/pnnx/pass_level7/PixelShufflePass.py @@ -0,0 +1,49 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F + +op_type = 'nn.PixelShuffle' + +class Model(nn.Module): + def __init__(self, upscale_factor): + super(Model, self).__init__() + self.upscale_factor = upscale_factor + pass + + def forward(self, *v_0): + v_1 = v_0[0] + batch_size, channels, in_height, in_width = v_1.size() + channels //= (self.upscale_factor ** 2) + out_height = in_height * self.upscale_factor + out_width = in_width * self.upscale_factor + + shuffled = v_1.view(batch_size, channels, self.upscale_factor, self.upscale_factor, in_height, in_width) + + shuffled = shuffled.permute(0, 1, 4, 2, 5, 3) + + output = shuffled.reshape(batch_size, channels, out_height, out_width) + return output + + + +def export_torchscript(upscale_factor, v_0, save_dir, op_name, attr_data = None, input_shapes = None): + net = Model(upscale_factor) + net.eval() + mod = torch.jit.trace(net, v_0) + pt_path = os.path.join(save_dir, op_name + '.pt').replace('\\','/') + mod.save(pt_path) + +def check_pass(): + v_0 = torch.rand(1,64,8,8, dtype = torch.float) + #finish your check pass code + model = Model(2) + model.eval() + o1 = model(v_0) + p = nn.PixelShuffle(2) + o2 = p(v_0) + print(o1.shape) + print(o1==o2) + +if __name__ == "__main__": + check_pass() diff --git a/tools/pnnx/tools/export.py b/tools/pnnx/tools/export.py index ef2c8aaba075..eec8ec678a4d 100644 --- a/tools/pnnx/tools/export.py +++ b/tools/pnnx/tools/export.py @@ -29,7 +29,7 @@ if platform.system() == "Windows": - save_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/model_zoo' + save_path = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\model_zoo' elif platform.system() == "Linux": save_path = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo' else: @@ -133,6 +133,23 @@ def forward(self, x, y): for i in range(int(y)): x = x + y return x + +class PixelShuffle(torch.nn.Module): + def __init__(self,): + super(PixelShuffle, self).__init__() + self.p = nn.PixelShuffle(upscale_factor=2) + def forward(self, x): + # y = self.p(x) + y = F.pixel_shuffle(x,upscale_factor=2) + return y + +class MultiHeadAttention(torch.nn.Module): + def __init__(self,): + super(MultiHeadAttention, self).__init__() + self.m = nn.MultiheadAttention(add_bias_kv=False, add_zero_attn=False, batch_first=False, bias=True, embed_dim=512, kdim=512, num_heads=8, vdim=512) + def forward(self, v_122): + y,_ = self.m(v_122, v_122, v_122, need_weights=False) + return y def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx: bool): if isinstance(input_shape, list): @@ -190,10 +207,12 @@ def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx "reshape_as": reshape_as_Model, "unfold":unfold_Model, "NMS":NMS, - "Script1":Script1 + "Script1":Script1, + "PixelShuffle":PixelShuffle, + "MultiHeadAttention":MultiHeadAttention } - model_name = 'Script1' + model_name = 'PixelShuffle' if model_name in net_map: net = net_map[model_name]() else: @@ -232,9 +251,23 @@ def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx # input_shape = [[4,4],[4]] # Script1 - i1 = torch.ones([5,5]) - i2 = torch.ones(1, dtype=torch.long) - input_shape = [i1,i2] + # i1 = torch.ones([5,5]) + # i2 = torch.ones(1, dtype=torch.long) + # input_shape = [i1,i2] + # export(model_name, net, input_shape, export_onnx) + + #PixelShuffle + # input_shape = [torch.randn([1,64,8,8])] + # export(model_name, net, input_shape, export_onnx) + + #multiHeadAttention + # input_shape = [torch.randn([256,1,512])] + # export(model_name, net, input_shape, export_onnx) + + # model_name = 'fovea' + net = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\model_zoo\fovea\end2end.pt' + input_shape = [[1,3,800,1216]] + export_onnx = False export(model_name, net, input_shape, export_onnx) # import pnnx # pnnx.export diff --git a/tools/pnnx/tools/gen_pass_level7_template.py b/tools/pnnx/tools/gen_pass_level7_template.py index 2a97787525c6..04d1de890c18 100644 --- a/tools/pnnx/tools/gen_pass_level7_template.py +++ b/tools/pnnx/tools/gen_pass_level7_template.py @@ -25,7 +25,7 @@ def gen_pass_level7_template(ops, output_path, pass_name): params = cur_op.params attribute = cur_op.attrs init_params_name = list(params.keys()) + list(attribute.keys()) - init_params_name.append('input_shapes') + # init_params_name.append('input_shapes') op_type = cur_op.type output_py_path = os.path.join(output_path, pass_name + '.py') @@ -38,7 +38,7 @@ def gen_pass_level7_template(ops, output_path, pass_name): f.write("\n") #op_type - f.write("op_type = '{}'\n ".format(op_type)) + f.write("op_type = ['{}']\n ".format(op_type)) f.write("\n") #define model @@ -107,11 +107,22 @@ def gen_pass_level7_template(ops, output_path, pass_name): input_shape_str = '[1,3,9,9]' pass_name = 'UnfoldPass_new' + # PixelShuffle + pt_path_str = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\model_zoo\PixelShuffle\PixelShuffle.pt' + input_shape_str = '[1,64,8,8]' + pass_name = 'PixelShuffle' + + + # MultiHeadAttention + pt_path_str = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\model_zoo\MultiHeadAttention\MultiHeadAttention.pt' + input_shape_str = '[256,1,512]' + pass_name = 'MultiHeadAttention' + # custom_op_path_str = # infer_py_path = # gen pnnx model if platform.system() == "Windows": - output_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/pass_level7/template' + output_path = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\pass_level7\template' elif platform.system() == "Linux": output_path = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/pass_level7/template' else: diff --git a/tools/pnnx/tools/pass_level7_dict.py b/tools/pnnx/tools/pass_level7_dict.py index 6aefb452ad1a..159b226729ae 100644 --- a/tools/pnnx/tools/pass_level7_dict.py +++ b/tools/pnnx/tools/pass_level7_dict.py @@ -10,6 +10,7 @@ def ParseParams(op, customOp_attrs = None): #parse parms #0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others for name, param in params.items(): + name = name.replace('.','_') param_type = param.type if param_type == 0: params_data[name] = None @@ -47,6 +48,7 @@ def ParseAttrs(op): #parse attrs attrs = op.attrs for name,attr in attrs.items(): + name = name.replace('.','_') sub_dict = {} sub_dict['shape'] = attr.shape if attr.type == 1: