Skip to content

Commit

Permalink
update pass_level7
Browse files Browse the repository at this point in the history
  • Loading branch information
sen.li committed Jul 4, 2024
1 parent a040847 commit 676db70
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 9 deletions.
49 changes: 49 additions & 0 deletions tools/pnnx/pass_level7/PixelShufflePass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

op_type = 'nn.PixelShuffle'

class Model(nn.Module):
def __init__(self, upscale_factor):
super(Model, self).__init__()
self.upscale_factor = upscale_factor
pass

def forward(self, *v_0):
v_1 = v_0[0]
batch_size, channels, in_height, in_width = v_1.size()
channels //= (self.upscale_factor ** 2)
out_height = in_height * self.upscale_factor
out_width = in_width * self.upscale_factor

shuffled = v_1.view(batch_size, channels, self.upscale_factor, self.upscale_factor, in_height, in_width)

shuffled = shuffled.permute(0, 1, 4, 2, 5, 3)

output = shuffled.reshape(batch_size, channels, out_height, out_width)
return output



def export_torchscript(upscale_factor, v_0, save_dir, op_name, attr_data = None, input_shapes = None):
net = Model(upscale_factor)
net.eval()
mod = torch.jit.trace(net, v_0)
pt_path = os.path.join(save_dir, op_name + '.pt').replace('\\','/')
mod.save(pt_path)

def check_pass():
v_0 = torch.rand(1,64,8,8, dtype = torch.float)
#finish your check pass code
model = Model(2)
model.eval()
o1 = model(v_0)
p = nn.PixelShuffle(2)
o2 = p(v_0)
print(o1.shape)
print(o1==o2)

if __name__ == "__main__":
check_pass()
45 changes: 39 additions & 6 deletions tools/pnnx/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


if platform.system() == "Windows":
save_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/model_zoo'
save_path = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\model_zoo'
elif platform.system() == "Linux":
save_path = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo'
else:
Expand Down Expand Up @@ -133,6 +133,23 @@ def forward(self, x, y):
for i in range(int(y)):
x = x + y
return x

class PixelShuffle(torch.nn.Module):
def __init__(self,):
super(PixelShuffle, self).__init__()
self.p = nn.PixelShuffle(upscale_factor=2)
def forward(self, x):
# y = self.p(x)
y = F.pixel_shuffle(x,upscale_factor=2)
return y

class MultiHeadAttention(torch.nn.Module):
def __init__(self,):
super(MultiHeadAttention, self).__init__()
self.m = nn.MultiheadAttention(add_bias_kv=False, add_zero_attn=False, batch_first=False, bias=True, embed_dim=512, kdim=512, num_heads=8, vdim=512)
def forward(self, v_122):
y,_ = self.m(v_122, v_122, v_122, need_weights=False)
return y

def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx: bool):
if isinstance(input_shape, list):
Expand Down Expand Up @@ -190,10 +207,12 @@ def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx
"reshape_as": reshape_as_Model,
"unfold":unfold_Model,
"NMS":NMS,
"Script1":Script1
"Script1":Script1,
"PixelShuffle":PixelShuffle,
"MultiHeadAttention":MultiHeadAttention
}

model_name = 'Script1'
model_name = 'PixelShuffle'
if model_name in net_map:
net = net_map[model_name]()
else:
Expand Down Expand Up @@ -232,9 +251,23 @@ def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx
# input_shape = [[4,4],[4]]

# Script1
i1 = torch.ones([5,5])
i2 = torch.ones(1, dtype=torch.long)
input_shape = [i1,i2]
# i1 = torch.ones([5,5])
# i2 = torch.ones(1, dtype=torch.long)
# input_shape = [i1,i2]
# export(model_name, net, input_shape, export_onnx)

#PixelShuffle
# input_shape = [torch.randn([1,64,8,8])]
# export(model_name, net, input_shape, export_onnx)

#multiHeadAttention
# input_shape = [torch.randn([256,1,512])]
# export(model_name, net, input_shape, export_onnx)

# model_name = 'fovea'
net = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\model_zoo\fovea\end2end.pt'
input_shape = [[1,3,800,1216]]
export_onnx = False
export(model_name, net, input_shape, export_onnx)
# import pnnx
# pnnx.export
Expand Down
17 changes: 14 additions & 3 deletions tools/pnnx/tools/gen_pass_level7_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def gen_pass_level7_template(ops, output_path, pass_name):
params = cur_op.params
attribute = cur_op.attrs
init_params_name = list(params.keys()) + list(attribute.keys())
init_params_name.append('input_shapes')
# init_params_name.append('input_shapes')
op_type = cur_op.type

output_py_path = os.path.join(output_path, pass_name + '.py')
Expand All @@ -38,7 +38,7 @@ def gen_pass_level7_template(ops, output_path, pass_name):
f.write("\n")

#op_type
f.write("op_type = '{}'\n ".format(op_type))
f.write("op_type = ['{}']\n ".format(op_type))
f.write("\n")

#define model
Expand Down Expand Up @@ -107,11 +107,22 @@ def gen_pass_level7_template(ops, output_path, pass_name):
input_shape_str = '[1,3,9,9]'
pass_name = 'UnfoldPass_new'

# PixelShuffle
pt_path_str = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\model_zoo\PixelShuffle\PixelShuffle.pt'
input_shape_str = '[1,64,8,8]'
pass_name = 'PixelShuffle'


# MultiHeadAttention
pt_path_str = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\model_zoo\MultiHeadAttention\MultiHeadAttention.pt'
input_shape_str = '[256,1,512]'
pass_name = 'MultiHeadAttention'

# custom_op_path_str =
# infer_py_path =
# gen pnnx model
if platform.system() == "Windows":
output_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/pass_level7/template'
output_path = r'D:\project\programs\ncnn_project\sub_model\ncnn\tools\pnnx\pass_level7\template'
elif platform.system() == "Linux":
output_path = '/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/pass_level7/template'
else:
Expand Down
2 changes: 2 additions & 0 deletions tools/pnnx/tools/pass_level7_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def ParseParams(op, customOp_attrs = None):
#parse parms
#0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
for name, param in params.items():
name = name.replace('.','_')
param_type = param.type
if param_type == 0:
params_data[name] = None
Expand Down Expand Up @@ -47,6 +48,7 @@ def ParseAttrs(op):
#parse attrs
attrs = op.attrs
for name,attr in attrs.items():
name = name.replace('.','_')
sub_dict = {}
sub_dict['shape'] = attr.shape
if attr.type == 1:
Expand Down

0 comments on commit 676db70

Please sign in to comment.