Skip to content

Commit

Permalink
add pass level7 dict
Browse files Browse the repository at this point in the history
  • Loading branch information
sen.li committed May 16, 2024
1 parent 4539597 commit 828c823
Show file tree
Hide file tree
Showing 2 changed files with 345 additions and 7 deletions.
337 changes: 337 additions & 0 deletions tools/pnnx/tools/pass_level7_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,337 @@
from serializer import *
import numpy as np
import struct
import copy
def ParseParams(op, customOp_attrs = None):
"""Convert a list of AttributeProto to a dict, with names as keys."""
params_data = {}
params = op.params

#parse parms
#0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
for name, param in params.items():
param_type = param.type
if param_type == 0:
params_data[name] = None
elif param_type == 1:
params_data[name] = param.b
elif param_type == 2:
params_data[name] = param.i
elif param_type == 3:
params_data[name] = param.f
elif param_type == 4:
params_data[name] = param.s
elif param_type == 5:
params_data[name] = param.a_i
elif param_type == 6:
params_data[name] = param.a_f
elif param_type == 7:
params_data[name] = param.a_s
else:
raise Exception("params type [{}] do not supported!".format(param_type))
if 'padding' in params_data and params_data['padding'] == 'same':
params_data['padding'] = (np.array(params_data['kernel_size']) // 2).tolist()
if customOp_attrs == None:
return params_data
else:
update_params_data = {}
for op_name, custom_op_name in customOp_attrs.items():
if custom_op_name not in params_data:
raise Exception("please check customOp_attrs {}:{}!".format(op_name, custom_op_name))
update_params_data[op_name] = params_data[custom_op_name]
return update_params_data


def ParseAttrs(op):
attrs_data = {}
#parse attrs
attrs = op.attrs
for name,attr in attrs.items():
sub_dict = {}
sub_dict['shape'] = attr.shape
if attr.type == 1:
dtype = 'float32'
elif attr.type == 2:
dtype = 'float64'
elif attr.type == 3:
dtype = 'float16'
elif attr.type == 4:
dtype = 'int32'
elif attr.type == 5:
dtype = 'int64'
elif attr.type == 6:
dtype = 'int16'
elif attr.type == 7:
dtype = 'int8'
elif attr.type == 8:
dtype = 'uint8'
elif attr.type == 9:
dtype = 'bool'
else:
raise Exception("attr.type [{}] do not supported!".format(attr.type))
if hasattr(attr,'b_data'):
sub_dict['data'] = np.frombuffer(attr.b_data, dtype=dtype)
else:
sub_dict['data'] = attr.data
attrs_data[name] = sub_dict
return attrs_data

def load_module(module_path):
spec = importlib.util.spec_from_file_location("module", module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


def debug_op(operators):
for op in operators:
print("----------cur op name :{}-------------".format(op.name))
op_in = [ i.name for i in op.inputs]
op_out = [ o.name for o in op.outputs]
print("inputs :{}".format(op_in))
print("outputs :{}".format(op_out))


def debug_operand(operands):
for tensor in operands:
print("----------cur tensor name :{}-------------".format(tensor.name))
producer = tensor.producer.name
consumers = [ o.name for o in tensor.consumers]
print("producer :{}".format(producer))
print("consumers :{}".format(consumers))


def get_src_node_info(op):
input_names, input_shapes, input_datas = [], [], []
attr_input_names, attr_input_datas = [], []
inOperands = op.inputs
for operand in inOperands:
if operand.producer.type == 'pnnx.Attribute':
attrs_params = ParseAttrs(operand.producer)
operand_dict = attrs_params["data"]
attr_input_names.append(operand.name)
attr_input_datas.append(operand_dict['data'].reshape(operand_dict['shape']))
else:
input_names.append(operand.name)
input_shapes.append(operand.shape)
input_datas.append(torch.rand(operand.shape, dtype = torch.float))


outOperands = op.outputs
output_names = [out_operand.name for out_operand in outOperands]
return input_names, input_shapes, input_datas, attr_input_names, attr_input_datas, output_names

def trans_list_to_dict(operator, operand, update_name = False, cur_op_name = ''):

def update_operand_name(operands,operator_update_name_dict,operand_update_name_dict):
for operand in operands:
if operand.name in operand_update_name_dict and operand_update_name_dict[operand.name] == False:
operand_update_name_dict[operand.name] = True
operand.name = cur_op_name + '_tensor_' + operand.name
# get producer consumers
producer = operand.producer
consumers = operand.consumers
update_operator_name([producer],operator_update_name_dict,operand_update_name_dict)
update_operator_name(consumers,operator_update_name_dict,operand_update_name_dict)

def update_operator_name(operator, operator_update_name_dict, operand_update_name_dict):
for index, op in enumerate(operator):
if op.name in operator_update_name_dict and operator_update_name_dict[op.name] == False:
operator_update_name_dict[op.name] = True
operator[index].name = cur_op_name + '_expand_' + operator[index].name
# get inputs outputs
inputs_operand = op.inputs
outputs_operand = op.outputs
update_operand_name(inputs_operand,operator_update_name_dict,operand_update_name_dict)
update_operand_name(outputs_operand,operator_update_name_dict,operand_update_name_dict)

if update_name:
operator_update_name_dict = {}
for op in operator:
operator_update_name_dict[op.name] =False

operand_update_name_dict = {}
for tensor in operand:
operand_update_name_dict[tensor.name] = False
new_operator_update_name_dict = operator_update_name_dict.copy()
new_operand_update_name_dict = operand_update_name_dict.copy()
update_operator_name(operator,operator_update_name_dict,operand_update_name_dict)
for op in operator:
op.name = cur_op_name + '_expand_' + op.name
update_operand_name(operand,new_operator_update_name_dict,new_operand_update_name_dict)

operator_dict = {op.name: op for op in operator}
operand_dict = {tensor.name: tensor for tensor in operand}
return operator_dict, operand_dict





def trans_dict_to_list(operator_dict, operand_dict):
operator = list(operator_dict.values())
operand = list(operand_dict.values())
return operator, operand


def get_pre_node_name(operand_dict, operand_names):
pre_node_name = []
for input_name in operand_names:
pre_node_name.append(operand_dict[input_name].producer.name)

return pre_node_name


if __name__ == "__main__":


parser = PnnxParser()
pt_path_str = 'D:/project/programs/my_project/tests/test_python/test_op/model_zoo2/stack_16/stack_16.pt'
input_shape_str = '[1,3,224,224],[1,3,224,224]'
# custom_op_path_str =
# infer_py_path =
pass_level7_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/pass_level7'
# gen pnnx model
operators, operands, input_ops, output_ops = parser.getNvpPnnxModel(pt_path_str, input_shape_str)
# trans list to dict for pass
operator_dict, operand_dict = trans_list_to_dict(operators, operands)


pass_level7_tmp_output_path = 'D:/project/programs/ncnn_project/ncnn/tools/pnnx/output/tmp'
# loop all pass level7
all_pass_files = os.listdir(pass_level7_path)
all_pass_files = [pass_file for pass_file in all_pass_files if pass_file not in ['__init__.py'] and not os.path.isdir(os.path.join(pass_level7_path, pass_file))]
for pass_file in all_pass_files:
pass_name, _ = os.path.splitext(pass_file)
print("run pass:{}".format(pass_name))
passMod = load_module(os.path.join(pass_level7_path, pass_file))
op_type = getattr(passMod, 'op_type')
export_pt = getattr(passMod, 'export_torchscript')
# loop all op
while True:
matched = False
for op_name, op in operator_dict.items():
if op.type == op_type:
matched = True

# -------run pass------

# 1. export pt

# get params and attr_dict
params_dict = ParseParams(op)
attrs_dict = ParseAttrs(op)
# update attrs_dict
for attrs_key, attrs_value in attrs_dict.items():
attrs_data = attrs_value['data']
attrs_shape = attrs_value['shape']
params_dict[attrs_key] = attrs_data.reshape(attrs_shape)


# get src node info
input_names, input_shapes, input_datas, \
attr_input_names, attr_input_datas,\
output_names = get_src_node_info(op)

# export pt
all_params_dict = params_dict.copy()
all_params_dict['v_0'] = input_datas
all_params_dict['save_dir'] = pass_level7_tmp_output_path
all_params_dict['op_name'] = op_name
all_params_dict['attr_data'] = [torch.from_numpy(attr_input_data) for attr_input_data in attr_input_datas]
export_pt(**all_params_dict)

pass_pt_path = os.path.join(pass_level7_tmp_output_path, op_name + '.pt').replace('\\','/')
pass_input_shape_str = ','.join([str(inner_list) for inner_list in input_shapes])
pass_input_shape_str.replace(' ','')
# 2. export pnnx
cur_parser = PnnxParser()
pass_operators, pass_operands, pass_input_ops, pass_output_ops = cur_parser.getNvpPnnxModel(pass_pt_path, pass_input_shape_str)
pass_operators_dict, pass_operands_dict = trans_list_to_dict(pass_operators, pass_operands, True, op_name)

attr_input_node_name = get_pre_node_name(operand_dict, attr_input_names)
del_op_names = [op.name] + attr_input_node_name
for del_op_name in del_op_names:
operator_dict.pop(del_op_name)

# insert pass op
input_index = 0
output_index = 0
for cur_pass_op_name, cur_pass_op in pass_operators_dict.items():
# for cur_pass_op in pass_operators:
if cur_pass_op.type == 'pnnx.Input':
# get src input operand
src_input_operand_name = input_names[input_index]
src_input_operand = operand_dict[src_input_operand_name]
# get src input node name
src_input_node_name = src_input_operand.producer.name
# get dst ops
cur_pass_input_operand = cur_pass_op.outputs[0]
cur_dst_ops = cur_pass_input_operand.consumers

# src_input_operand connect new node
src_input_operand.consumers = [ consumers for consumers in src_input_operand.consumers if consumers.name != op_name ] + cur_dst_ops
for dst_op in cur_dst_ops:
dsp_op_name = dst_op.name
# pass_operators_dict[dsp_op_name].inputs =
pass_operators_dict[dsp_op_name].inputs = [ src_input_operand if d_input.name == cur_pass_input_operand.name else d_input for d_input in pass_operators_dict[dsp_op_name].inputs]

# src_input node connect new node
src_input_node = operator_dict[src_input_node_name]
for src_input_node_out in src_input_node.outputs:
src_input_node_out.consumers = [ out_cons for out_cons in src_input_node_out.consumers if out_cons.name != op_name] + cur_dst_ops
#
input_index += 1


elif cur_pass_op.type == 'pnnx.Output':
src_output_name = output_names[output_index]
src_output_operand = operand_dict[src_output_name]

dst_output_op = cur_pass_op.inputs[0].producer
src_output_operand.producer = dst_output_op
dst_output_op_name = dst_output_op.name
pass_operators_dict[dst_output_op_name].outputs = [src_output_operand]

# sink node connect new node
src_output_node_names = [ con.name for con in src_output_operand.consumers]
for src_output_node_name in src_output_node_names:
for input_operand in operator_dict[src_output_node_name].inputs:
if input_operand.producer.name == op_name:
input_operand.producer.name = dst_output_op.name
output_index += 1
else:

operator_dict[cur_pass_op.name] = cur_pass_op

#delect src attr operands
for attr_input_name in attr_input_names:
operand_dict.pop(attr_input_name)

# insert pass operand
for cur_pass_operand in pass_operands:
if cur_pass_operand.producer.type != 'pnnx.Input' and cur_pass_operand.consumers[0].type != 'pnnx.Output':

operand_dict[cur_pass_operand.name] = cur_pass_operand
# debug info
print('finish pass {} in {}'.format(pass_name, op_name))
break

if not matched:
break
operators, operands = trans_dict_to_list(operator_dict, operand_dict)
debug_op(operators)
debug_operand(operands)













15 changes: 8 additions & 7 deletions tools/pnnx/tools/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
sys.path.append('D:/project/programs/ncnn_project/ncnn/tools/pnnx/python/build/lib.win-amd64-cpython-38/pnnx')
# sys.path.append('/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/python/build/temp.linux-x86_64-cpython-311/src')
import ptx
graph = ptx.PnnxGraph()

except ImportError as e:
sys.exit(str(e))

Expand Down Expand Up @@ -61,6 +61,7 @@ def __init__(self,):
as: list(str) type==7
"""
self.graph = ptx.PnnxGraph()

def LoadModel(self, params_path: str, bin_path: str):
"""
Expand All @@ -75,12 +76,12 @@ def LoadModel(self, params_path: str, bin_path: str):
input_ops list(Operator)
output_ops list(Operator)
"""
a = graph.loadModel(params_path,bin_path)
a = self.graph.loadModel(params_path,bin_path)
assert(a is True, "please check your you input path")
operators = graph.getOperators()
operands = graph.getOperands()
input_ops = graph.getInputOps()
output_ops = graph.getOutputOps()
operators = self.graph.getOperators()
operands = self.graph.getOperands()
input_ops = self.graph.getInputOps()
output_ops = self.graph.getOutputOps()

return operators, operands, input_ops, output_ops

Expand All @@ -100,7 +101,7 @@ def getNvpPnnxModel(self, pt_path_str: str, input_shape_str: str, custom_op_path
input_ops list(Operator)
output_ops list(Operator)
"""
result = graph.getNvpPnnxModel(pt_path_str, input_shape_str, custom_op_path_str, infer_py_path)
result = self.graph.getNvpPnnxModel(pt_path_str, input_shape_str, custom_op_path_str, infer_py_path)
assert(result, "get pnnx model failed")
params_path = pt_path_str.replace('.pt','.pnnx.param')
bin_path = pt_path_str.replace('.pt','.pnnx.bin')
Expand Down

0 comments on commit 828c823

Please sign in to comment.