extract sub graph
root committed May 13, 2024
1 parent a00235f commit 022d9c3
Showing 1 changed file with 306 additions and 0 deletions.
tools/pnnx/seralizer/seralizer.py
@@ -0,0 +1,306 @@
import sys
from typing import List, Union, Optional
import argparse
import os
import shutil
import json
import importlib
import re
try:
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
# sys.path.append('/workspace/trans_onnx/project/new_project/nvppnnx/python/build/temp.linux-x86_64-cpython-311/src')
sys.path.append('/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/python/build/temp.linux-x86_64-cpython-311/src')
import ptx
graph = ptx.PnnxGraph()
except ImportError as e:
sys.exit(str(e))

def extract_content_between_parentheses(text):
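    """Return every substring enclosed in parentheses, e.g. '(1,3,224,224)' -> ['1,3,224,224']."""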
pattern = r'\((.*?)\)'
matches = re.findall(pattern, text)
return matches

class PnnxParser():

    def __init__(self):
"""pnnx ir description
Operator:
intputs: list(Operand)
outputs: list(Operand)
type: str
name: str
inputnames: list(str)
params: dict{str:Parameter}
attrs: dict{str:Attribute}
Operand:
producer: Operator
consumers: list(Operator)
type: int // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=cp64 11=cp128 12=cp32
shape: list(int)
name: str
params: dict{str:Parameter}
attrs: dict{str:Attribute}
Attribute:
type: int // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool
shape: list(int)
data: list(char)
Parameter:
type: int //0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
b: bool type==1
i: int type==2
f: float type==3
ai: list(int) type==4
af: list(float) type==5
s: str type==6
as: list(str) type==7
"""

def LoadModel(self, params_path: str, bin_path: str):
"""
Args:
params_path (str): the path of pnnx.params
bin_path (str): the path of pnnx.bin
Returns:
operators list(Operator)
operands list(Operand)
input_ops list(Operator)
output_ops list(Operator)
"""
        a = graph.loadModel(params_path, bin_path)
        assert a is True, "please check your input path"
operators = graph.getOperators()
operands = graph.getOperands()
input_ops = graph.getInputOps()
output_ops = graph.getOutputOps()
return operators, operands, input_ops, output_ops
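        # Minimal usage sketch (paths are placeholders; assumes the ptx binding
        # exposes the Operator fields documented in __init__):
        #   parser = PnnxParser()
        #   ops, operands, in_ops, out_ops = parser.LoadModel('model.pnnx.param', 'model.pnnx.bin')
        #   for op in ops:
        #       print(op.type, op.name)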

def getNvpPnnxModel(self, pt_path_str: str, input_shape_str: str, custom_op_path_str: str = 'None', infer_py_path: str = 'None'):
"""_summary_
Args:
pt_path_str (str): the path of pt
input_shape_str (str): the shape of input
custom_op_path_str (str, optional): the path of custom op. Defaults to 'None'.
            infer_py_path (str, optional): the path of the executor script. Defaults to 'None'.
Returns:
operators list(Operator)
operands list(Operand)
input_ops list(Operator)
output_ops list(Operator)
"""
result = graph.getNvpPnnxModel(pt_path_str, input_shape_str, custom_op_path_str, infer_py_path)
        assert result, "get pnnx model failed"
params_path = pt_path_str.replace('.pt','.pnnx.param')
bin_path = pt_path_str.replace('.pt','.pnnx.bin')
return self.LoadModel(params_path, bin_path)
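        # Usage sketch (paths and shape string are placeholders; the exact shape
        # syntax is whatever graph.getNvpPnnxModel expects):
        #   ops, operands, in_ops, out_ops = parser.getNvpPnnxModel('model.pt', '[1,3,224,224]')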

    def extract(self, params_path: str, infer_path: str, output_path: str, start_tensor_name: str, end_tensor_name: str = 'None'):
        """The first version of the subgraph extraction function: it mainly records \
        the start and end line information of the params file to construct the subgraph.
Args:
params_path (str): the path of params file
infer_path (str): the path of infer py
output_path (str): the output path of new params file and new infer py
start_tensor_name (str): the name of input tensor in sub graph
end_tensor_name (str): the name of output tensor in sub graph. Defaults to 'None'.
"""
        assert os.path.exists(params_path), 'the params path {} does not exist'.format(params_path)
        assert os.path.exists(infer_path), 'the infer py path {} does not exist'.format(infer_path)
        assert os.path.isdir(output_path), 'the output path {} is not a directory'.format(output_path)
        with open(params_path, mode='r') as f:
params_lines = f.readlines()
input_shape = ''
output_shape = ''
start_index = -1
end_index = -1
adjacency_matrix_dict = {}
for index, line in enumerate(params_lines):
if index < 2:
continue
print(line.split())
op_info_list = line.split()
op_name = op_info_list[1]
input_nums = op_info_list[2]
output_nums = op_info_list[3]
input_name_list = []
for i in range(int(input_nums)):
input_name_list.append(op_info_list[4 + i])
output_name_list = []
for j in range(int(output_nums)):
output_name_list.append(op_info_list[4 + int(input_nums) + j])

            if len(input_name_list) == 0:
                # input node or attribute op: record its outputs so later ops can trace back to it
                if start_index != -1:
                    for output_name in output_name_list:
                        adjacency_matrix_dict[output_name] = {'input_names': input_name_list, 'index_info': [index]}
                continue
if start_tensor_name == input_name_list[0] and len(input_name_list) == 1:
start_index = index
#get input shape
tensor_shape = []
for op_info in op_info_list:
if op_info.startswith('#'):
tensor_shape.append(op_info)
input_shape = tensor_shape[0]

            if end_tensor_name == output_name_list[0] and len(output_name_list) == 1:
end_index = index
for output_name in output_name_list:
adjacency_matrix_dict[output_name] = {'input_names':input_name_list, 'index_info':[index]}
#get output shape
tensor_shape = []
for op_info in op_info_list:
if op_info.startswith('#'):
tensor_shape.append(op_info)
output_shape = tensor_shape[-1]


break
if start_index !=-1:
for output_name in output_name_list:
if output_name in adjacency_matrix_dict:
adjacency_matrix_dict[output_name]['input_names'].extend(input_name_list)
adjacency_matrix_dict[output_name]['index_info'].append(index)
else:
adjacency_matrix_dict[output_name] = {'input_names':input_name_list, 'index_info':[index]}

if start_index > end_index or start_index == -1 or end_index == -1:
assert False, 'please check your start_tensor_name and end_tensor_name!'

# from end node to start node to get real sub graph
cur_input_name_list = adjacency_matrix_dict[end_tensor_name]['input_names']
cur_index_info_list = adjacency_matrix_dict[end_tensor_name]['index_info']

tmp_index_list = []

def backtrack(cur_input_name_list, cur_index_info_list):
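            # Depth-first walk from the end tensor back through adjacency_matrix_dict;
            # whenever a path reaches start_tensor_name, every param-file line index
            # gathered along that path is collected into tmp_index_list.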
if start_tensor_name in cur_input_name_list:
tmp_index_list.extend(cur_index_info_list)
return
for cur_input in cur_input_name_list:
if cur_input in adjacency_matrix_dict:
new_cur_input_name_list = adjacency_matrix_dict[cur_input]['input_names']
index_info_list = adjacency_matrix_dict[cur_input]['index_info']
                    new_cur_index_info_list = cur_index_info_list.copy()
                    new_cur_index_info_list.extend(index_info_list)
                    backtrack(new_cur_input_name_list, new_cur_index_info_list)

        backtrack(cur_input_name_list, cur_index_info_list)

#sorted
unique_list = list(set(tmp_index_list))
sorted_list = sorted(unique_list)
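        # sorted_list now holds, in file order, the param-file line numbers of every
        # operator lying on a path from start_tensor_name to end_tensor_name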
new_params_lines = [params_lines[i] for i in sorted_list]
#get sub_graph all tensor
sub_graph_all_tensor = []
for new_param_line in new_params_lines:
new_op_info_list = new_param_line.split()
new_op_name = new_op_info_list[1]
new_op_input_nums = new_op_info_list[2]
new_op_output_nums = new_op_info_list[3]
new_op_output_name_list = []
for j in range(int(new_op_output_nums)):
new_op_output_name_list.append(new_op_info_list[4 + int(new_op_input_nums) + j])
sub_graph_all_tensor.extend(new_op_output_name_list)
sub_graph_all_tensor = list(set(sub_graph_all_tensor))
        # insert the new input and output nodes for the sub graph
input_node = 'pnnx.Input pnnx_input_0 0 1 {} {}\n'.format(start_tensor_name, input_shape)
output_node = 'pnnx.Output pnnx_output_0 1 0 {} {}\n'.format(end_tensor_name, output_shape)
new_params_lines.insert(0,input_node)
new_params_lines.append(output_node)
ops_nums = len(new_params_lines)
tensor_nums = ops_nums - 1
new_params_lines.insert(0,'{} {}\n'.format(ops_nums,tensor_nums))
new_params_lines.insert(0,'7767517\n')
new_params_file_path = os.path.join(output_path,'{}-{}_mode.pnnx.param'.format(start_tensor_name, end_tensor_name))
with open(new_params_file_path, 'w') as file:
file.writelines(new_params_lines)
new_infer_file_path = os.path.join(output_path,'{}-{}_mode_pnnx_infer.py'.format(start_tensor_name, end_tensor_name))
self.Extract_forward(infer_path, start_tensor_name, end_tensor_name, input_shape, new_infer_file_path,sub_graph_all_tensor)

def Extract_forward(self, infer_path, start_tensor_name, end_tensor_name, input_shape, new_infer_file_path, sub_graph_all_tensor):

        with open(infer_path, mode='r') as f:
infer_lines = f.readlines()
        # get input shape line index
        # get forward start_index
        # get forward end_index
        # get start_tensor_index
        # get end_tensor_index
        # get output_tensor_index
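        # The markers below assume the generated infer script follows a fixed skeleton,
        # roughly like this (shape and tensor names are illustrative only):
        #   def getInput(self,):
        #       return [[1, 3, 224, 224]]
        #   def forward(self, v_0):
        #       v_216 = ...            # one assignment per tensor
        #       ...
        #       if self.infer_flag:
        #           ...                # the line after this becomes the new return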
for infer_line_index, infer_line_info in enumerate(infer_lines):
infer_line_info_list = infer_line_info.split()
if infer_line_info_list == ['def','getInput(self,):']:
input_shape_line_index = infer_line_index + 1
elif infer_line_info_list == ['def', 'forward(self,', 'v_0):']:
                forward_start_index = infer_line_index
elif infer_line_info_list == ['return', 'intermediate']:
                forward_end_index = infer_line_index
elif len(infer_line_info_list) > 1 and infer_line_info_list[0] == 'v_' + start_tensor_name:
start_tensor_index = infer_line_index + 1
elif len(infer_line_info_list) > 1 and infer_line_info_list[0] == 'v_' + end_tensor_name:
end_tensor_index = infer_line_index
elif infer_line_info_list == ['if', 'self.infer_flag:']:
output_tensor_index = infer_line_index + 1

#update input shape
infer_lines[input_shape_line_index] = ' return [[' + extract_content_between_parentheses(input_shape)[0] + ']]\n'
        # update the forward signature to take the sub graph input tensor
        infer_lines[forward_start_index] = ' def forward(self, v_{}):\n'.format(start_tensor_name)
        # update the line after 'if self.infer_flag:' to return the sub graph output tensor
        infer_lines[output_tensor_index] = ' return v_{}\n'.format(end_tensor_name)
#get sub graph forward
sub_graph_line_indexs = []
for forward_index in range(start_tensor_index,end_tensor_index + 1):
o_tensor = infer_lines[forward_index].split()[0].split('_')[1]
if o_tensor in sub_graph_all_tensor:
sub_graph_line_indexs.append(forward_index)
sub_graph_line_indexs = list(set(sub_graph_line_indexs))
sub_graph_line_indexs = sorted(sub_graph_line_indexs)
#get new infer_lines
sub_graph_lines = [infer_lines[i] for i in sub_graph_line_indexs]
        new_infer_lines = infer_lines[:forward_start_index + 1] + sub_graph_lines + \
infer_lines[output_tensor_index - 1: ]
with open(new_infer_file_path, 'w') as file:
file.writelines(new_infer_lines)


if __name__ == "__main__":


parser = PnnxParser()
params_path = "/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo/n16x_torch_fcn_r101b_d8_4xb2_80k_cityscapes_512x1024/model.pnnx.param"
infer_path = "/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo/n16x_torch_fcn_r101b_d8_4xb2_80k_cityscapes_512x1024/model_pnnx_infer.py"
output_path = "/workspace/trans_onnx/project/new_project/ncnn/tools/pnnx/model_zoo/n16x_torch_fcn_r101b_d8_4xb2_80k_cityscapes_512x1024/output"
start_tensor_name = '216'
end_tensor_name = '248'
parser.extract(params_path, infer_path, output_path, start_tensor_name, end_tensor_name)
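    # with these arguments extract() writes 216-248_mode.pnnx.param and
    # 216-248_mode_pnnx_infer.py into output_path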


