diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt
index 5c6c798767d..836fbc38f4c 100644
--- a/tools/pnnx/CMakeLists.txt
+++ b/tools/pnnx/CMakeLists.txt
@@ -29,6 +29,13 @@ set(CMAKE_CXX_STANDARD 14)
 # set(CMAKE_BUILD_TYPE relwithdebinfo)
 # set(CMAKE_BUILD_TYPE release)
 
+string(TOUPPER "${CMAKE_BUILD_TYPE}" BUILD_TYPE_UPPER)
+if(BUILD_TYPE_UPPER STREQUAL "DEBUG")
+    add_compile_definitions(DEBUG)
+else()
+    add_compile_definitions(NDEBUG)
+endif()
+
 option(PNNX_COVERAGE "build for coverage" OFF)
 
 # set(Torch_INSTALL_DIR "/home/nihui/osd/pnnx/install" CACHE STRING "")
diff --git a/tools/pnnx/README.md b/tools/pnnx/README.md
index d7f40eb03ed..bd07ba7aeb4 100644
--- a/tools/pnnx/README.md
+++ b/tools/pnnx/README.md
@@ -717,6 +717,41 @@ TORCH_LIBRARY(upfirdn2d_op, m) {
 ]
 }
 ```
+
+2. Debug the C++ code directly through the pnnx executable
+```json
+
+{
+    "version": "0.2.0",
+    "configurations": [
+
+    {
+        "name": "msvc",
+        "type": "cppvsdbg",
+        "request": "launch",
+        // "program": "${workspaceFolder}/bin/test_ReduceL1_wrapper.exe",
+        "program": "${workspaceFolder}/python/build/lib.win-amd64-cpython-38/pnnx/Debug/pnnx.exe",
+
+        // "args": ["D:\\project\\programs\\ncnn_project\\nvppnnx\\model_zoo\\segformer\\model.pt",
+        // "inputshape=[1,3,512,512]","start_nodes=597,591,585,579","end_nodes=598"],
+
+        "args": ["D:/project/programs/my_project/tests/test_python/test_op/model_zoo3/script_test/test.pt",
+        "D:/project/programs/ncnn_project/nvppnnx/model_zoo/script_test",
+        "inputshape=[2,3],[1]i64"],
+        "stopAtEntry": false,
+        "cwd": "${workspaceFolder}",
+        "environment": [],
+        "externalConsole": false,
+        // "preLaunchTask": "task of build with msvc"
+    }
]
}
+```
+
+The command-line arguments have changed: the second positional argument is now the directory where the pnnx output files are saved.
+
+Added an "extract_model_name" argument that specifies the name of the network to extract; it can be either the main network or a sub-network.
+
 # Custom operator examples
 Custom operator examples are provided under the mytests folder; see
 |op_name|path|
diff --git a/tools/pnnx/Releasenotes b/tools/pnnx/Releasenotes
index 8d8a62e88cb..4613d8c4c0a 100644
--- a/tools/pnnx/Releasenotes
+++ b/tools/pnnx/Releasenotes
@@ -63,4 +63,10 @@ dev.1.0.18.20240613
 1. Skip conv2d nodes of type NoneType
 
 dev.1.0.19.20240614
-1. Add extracting sub graph function
\ No newline at end of file
+1. Add extracting sub graph function
+
+dev.1.0.20.20240617
+1. Add loop op parse function
+
+dev.1.0.21.20240619
+1.
Support export sub_model \ No newline at end of file diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index a7b9385c4bc..13826f6975e 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -385,200 +385,201 @@ set(pnnx_pass_level6_SRCS pass_level6/trans_Stack2Unsqueeze.cpp pass_level6/trans_ReshapeAs2Reshape.cpp pass_level6/trans_TensorTypeAs2TensorTo.cpp + pass_level6/fold_Loop.cpp ) -set(pnnx_pass_ncnn_SRCS - pass_ncnn/convert_attribute.cpp - pass_ncnn/convert_custom_op.cpp - pass_ncnn/convert_module_op.cpp - pass_ncnn/convert_half_to_float.cpp - pass_ncnn/convert_input.cpp - pass_ncnn/convert_torch_cat.cpp - pass_ncnn/convert_torch_chunk.cpp - pass_ncnn/convert_torch_einsum.cpp - pass_ncnn/convert_torch_split.cpp - pass_ncnn/convert_torch_stack.cpp - pass_ncnn/convert_torch_tensor_split.cpp - pass_ncnn/convert_torch_unbind.cpp - pass_ncnn/convert_Tensor_select.cpp - pass_ncnn/convert_Tensor_slice.cpp - pass_ncnn/convert_Tensor_slice_copy.cpp - pass_ncnn/eliminate_output.cpp - pass_ncnn/expand_expression.cpp - pass_ncnn/fuse_convert_shufflechannel_slice.cpp - pass_ncnn/insert_split.cpp - pass_ncnn/chain_multi_output.cpp - pass_ncnn/solve_batch_index.cpp - - pass_ncnn/eliminate_noop.cpp - pass_ncnn/eliminate_tail_reshape_permute.cpp - pass_ncnn/fuse_convolution_activation.cpp - pass_ncnn/fuse_convolution1d_activation.cpp - pass_ncnn/fuse_convolutiondepthwise_activation.cpp - pass_ncnn/fuse_convolutiondepthwise1d_activation.cpp - pass_ncnn/fuse_deconvolution_activation.cpp - pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp - pass_ncnn/fuse_innerproduct_activation.cpp - pass_ncnn/fuse_transpose_matmul.cpp - pass_ncnn/fuse_binaryop_eltwise.cpp - pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp - pass_ncnn/insert_reshape_linear.cpp - pass_ncnn/insert_reshape_pooling.cpp - pass_ncnn/insert_reshape_global_pooling.cpp - - pass_ncnn/F_adaptive_avg_pool1d.cpp - pass_ncnn/F_adaptive_avg_pool2d.cpp - pass_ncnn/F_adaptive_avg_pool3d.cpp - pass_ncnn/F_adaptive_max_pool1d.cpp - pass_ncnn/F_adaptive_max_pool2d.cpp - pass_ncnn/F_adaptive_max_pool3d.cpp - pass_ncnn/F_avg_pool1d.cpp - pass_ncnn/F_avg_pool2d.cpp - pass_ncnn/F_avg_pool3d.cpp - pass_ncnn/F_batch_norm.cpp - pass_ncnn/F_celu.cpp - pass_ncnn/F_conv_transpose1d.cpp - pass_ncnn/F_conv_transpose2d.cpp - pass_ncnn/F_conv_transpose3d.cpp - pass_ncnn/F_conv1d.cpp - pass_ncnn/F_conv2d.cpp - pass_ncnn/F_conv3d.cpp - pass_ncnn/F_elu.cpp - pass_ncnn/F_embedding.cpp - pass_ncnn/F_fold.cpp - pass_ncnn/F_gelu.cpp - pass_ncnn/F_glu.cpp - pass_ncnn/F_grid_sample.cpp - pass_ncnn/F_group_norm.cpp - pass_ncnn/F_hardsigmoid.cpp - pass_ncnn/F_hardswish.cpp - pass_ncnn/F_hardtanh.cpp - pass_ncnn/F_instance_norm.cpp - pass_ncnn/F_interpolate.cpp - pass_ncnn/F_layer_norm.cpp - pass_ncnn/F_leaky_relu.cpp - pass_ncnn/F_linear.cpp - pass_ncnn/F_local_response_norm.cpp - pass_ncnn/F_log_softmax.cpp - pass_ncnn/F_logsigmoid.cpp - pass_ncnn/F_max_pool1d.cpp - pass_ncnn/F_max_pool2d.cpp - pass_ncnn/F_max_pool3d.cpp - pass_ncnn/F_mish.cpp - pass_ncnn/F_normalize.cpp - pass_ncnn/F_pad.cpp - pass_ncnn/F_pixel_shuffle.cpp - pass_ncnn/F_pixel_unshuffle.cpp - pass_ncnn/F_prelu.cpp - pass_ncnn/F_relu.cpp - pass_ncnn/F_relu6.cpp - pass_ncnn/F_selu.cpp - pass_ncnn/F_sigmoid.cpp - pass_ncnn/F_silu.cpp - pass_ncnn/F_softmax.cpp - pass_ncnn/F_tanh.cpp - pass_ncnn/F_unfold.cpp - pass_ncnn/F_upsample_bilinear.cpp - pass_ncnn/F_upsample_nearest.cpp - pass_ncnn/F_upsample.cpp - pass_ncnn/nn_AdaptiveAvgPool1d.cpp - 
pass_ncnn/nn_AdaptiveAvgPool2d.cpp - pass_ncnn/nn_AdaptiveAvgPool3d.cpp - pass_ncnn/nn_AdaptiveMaxPool1d.cpp - pass_ncnn/nn_AdaptiveMaxPool2d.cpp - pass_ncnn/nn_AdaptiveMaxPool3d.cpp - pass_ncnn/nn_AvgPool1d.cpp - pass_ncnn/nn_AvgPool2d.cpp - pass_ncnn/nn_AvgPool3d.cpp - pass_ncnn/nn_BatchNorm1d.cpp - pass_ncnn/nn_BatchNorm2d.cpp - pass_ncnn/nn_BatchNorm3d.cpp - pass_ncnn/nn_CELU.cpp - pass_ncnn/nn_ChannelShuffle.cpp - pass_ncnn/nn_ConstantPad1d.cpp - pass_ncnn/nn_ConstantPad2d.cpp - pass_ncnn/nn_ConstantPad3d.cpp - pass_ncnn/nn_Conv1d.cpp - pass_ncnn/nn_Conv2d.cpp - pass_ncnn/nn_Conv3d.cpp - pass_ncnn/nn_ConvTranspose1d.cpp - pass_ncnn/nn_ConvTranspose2d.cpp - pass_ncnn/nn_ConvTranspose3d.cpp - pass_ncnn/nn_ELU.cpp - pass_ncnn/nn_Embedding.cpp - pass_ncnn/nn_Fold.cpp - pass_ncnn/nn_GELU.cpp - pass_ncnn/nn_GLU.cpp - pass_ncnn/nn_GroupNorm.cpp - pass_ncnn/nn_GRU.cpp - pass_ncnn/nn_Hardsigmoid.cpp - pass_ncnn/nn_Hardswish.cpp - pass_ncnn/nn_Hardtanh.cpp - pass_ncnn/nn_InstanceNorm2d.cpp - pass_ncnn/nn_LayerNorm.cpp - pass_ncnn/nn_LeakyReLU.cpp - pass_ncnn/nn_Linear.cpp - pass_ncnn/nn_LocalResponseNorm.cpp - pass_ncnn/nn_LogSigmoid.cpp - pass_ncnn/nn_LogSoftmax.cpp - pass_ncnn/nn_LSTM.cpp - pass_ncnn/nn_MaxPool1d.cpp - pass_ncnn/nn_MaxPool2d.cpp - pass_ncnn/nn_MaxPool3d.cpp - pass_ncnn/nn_Mish.cpp - pass_ncnn/nn_MultiheadAttention.cpp - pass_ncnn/nn_PixelShuffle.cpp - pass_ncnn/nn_PixelUnshuffle.cpp - pass_ncnn/nn_PReLU.cpp - pass_ncnn/nn_ReflectionPad1d.cpp - pass_ncnn/nn_ReflectionPad2d.cpp - pass_ncnn/nn_ReLU.cpp - pass_ncnn/nn_ReLU6.cpp - pass_ncnn/nn_ReplicationPad1d.cpp - pass_ncnn/nn_ReplicationPad2d.cpp - pass_ncnn/nn_ReplicationPad3d.cpp - pass_ncnn/nn_RNN.cpp - pass_ncnn/nn_SELU.cpp - pass_ncnn/nn_Sigmoid.cpp - pass_ncnn/nn_SiLU.cpp - pass_ncnn/nn_Softmax.cpp - pass_ncnn/nn_Softmax2d.cpp - pass_ncnn/nn_Tanh.cpp - pass_ncnn/nn_Unfold.cpp - pass_ncnn/nn_Upsample.cpp - pass_ncnn/nn_UpsamplingBilinear2d.cpp - pass_ncnn/nn_UpsamplingNearest2d.cpp - pass_ncnn/nn_ZeroPad2d.cpp - pass_ncnn/Tensor_contiguous.cpp - pass_ncnn/Tensor_reshape.cpp - pass_ncnn/Tensor_repeat.cpp - pass_ncnn/Tensor_view.cpp - pass_ncnn/torch_addmm.cpp - pass_ncnn/torch_amax.cpp - pass_ncnn/torch_amin.cpp - pass_ncnn/torch_bmm.cpp - pass_ncnn/torch_clamp.cpp - pass_ncnn/torch_clone.cpp - pass_ncnn/torch_cumsum.cpp - pass_ncnn/torch_diag.cpp - pass_ncnn/torch_flatten.cpp - pass_ncnn/torch_logsumexp.cpp - pass_ncnn/torch_matmul.cpp - pass_ncnn/torch_max.cpp - pass_ncnn/torch_mean.cpp - pass_ncnn/torch_min.cpp - pass_ncnn/torch_mm.cpp - pass_ncnn/torch_norm.cpp - pass_ncnn/torch_permute.cpp - pass_ncnn/torch_prod.cpp - pass_ncnn/torch_slice_scatter.cpp - pass_ncnn/torch_squeeze.cpp - pass_ncnn/torch_sum.cpp - pass_ncnn/torch_t.cpp - pass_ncnn/torch_transpose.cpp - pass_ncnn/torch_unsqueeze.cpp - pass_ncnn/torchvision_DeformConv2d.cpp -) +# set(pnnx_pass_ncnn_SRCS +# pass_ncnn/convert_attribute.cpp +# pass_ncnn/convert_custom_op.cpp +# pass_ncnn/convert_module_op.cpp +# pass_ncnn/convert_half_to_float.cpp +# pass_ncnn/convert_input.cpp +# pass_ncnn/convert_torch_cat.cpp +# pass_ncnn/convert_torch_chunk.cpp +# pass_ncnn/convert_torch_einsum.cpp +# pass_ncnn/convert_torch_split.cpp +# pass_ncnn/convert_torch_stack.cpp +# pass_ncnn/convert_torch_tensor_split.cpp +# pass_ncnn/convert_torch_unbind.cpp +# pass_ncnn/convert_Tensor_select.cpp +# pass_ncnn/convert_Tensor_slice.cpp +# pass_ncnn/convert_Tensor_slice_copy.cpp +# pass_ncnn/eliminate_output.cpp +# pass_ncnn/expand_expression.cpp +# 
pass_ncnn/fuse_convert_shufflechannel_slice.cpp +# pass_ncnn/insert_split.cpp +# pass_ncnn/chain_multi_output.cpp +# pass_ncnn/solve_batch_index.cpp + +# pass_ncnn/eliminate_noop.cpp +# pass_ncnn/eliminate_tail_reshape_permute.cpp +# pass_ncnn/fuse_convolution_activation.cpp +# pass_ncnn/fuse_convolution1d_activation.cpp +# pass_ncnn/fuse_convolutiondepthwise_activation.cpp +# pass_ncnn/fuse_convolutiondepthwise1d_activation.cpp +# pass_ncnn/fuse_deconvolution_activation.cpp +# pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp +# pass_ncnn/fuse_innerproduct_activation.cpp +# pass_ncnn/fuse_transpose_matmul.cpp +# pass_ncnn/fuse_binaryop_eltwise.cpp +# pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp +# pass_ncnn/insert_reshape_linear.cpp +# pass_ncnn/insert_reshape_pooling.cpp +# pass_ncnn/insert_reshape_global_pooling.cpp + +# pass_ncnn/F_adaptive_avg_pool1d.cpp +# pass_ncnn/F_adaptive_avg_pool2d.cpp +# pass_ncnn/F_adaptive_avg_pool3d.cpp +# pass_ncnn/F_adaptive_max_pool1d.cpp +# pass_ncnn/F_adaptive_max_pool2d.cpp +# pass_ncnn/F_adaptive_max_pool3d.cpp +# pass_ncnn/F_avg_pool1d.cpp +# pass_ncnn/F_avg_pool2d.cpp +# pass_ncnn/F_avg_pool3d.cpp +# pass_ncnn/F_batch_norm.cpp +# pass_ncnn/F_celu.cpp +# pass_ncnn/F_conv_transpose1d.cpp +# pass_ncnn/F_conv_transpose2d.cpp +# pass_ncnn/F_conv_transpose3d.cpp +# pass_ncnn/F_conv1d.cpp +# pass_ncnn/F_conv2d.cpp +# pass_ncnn/F_conv3d.cpp +# pass_ncnn/F_elu.cpp +# pass_ncnn/F_embedding.cpp +# pass_ncnn/F_fold.cpp +# pass_ncnn/F_gelu.cpp +# pass_ncnn/F_glu.cpp +# pass_ncnn/F_grid_sample.cpp +# pass_ncnn/F_group_norm.cpp +# pass_ncnn/F_hardsigmoid.cpp +# pass_ncnn/F_hardswish.cpp +# pass_ncnn/F_hardtanh.cpp +# pass_ncnn/F_instance_norm.cpp +# pass_ncnn/F_interpolate.cpp +# pass_ncnn/F_layer_norm.cpp +# pass_ncnn/F_leaky_relu.cpp +# pass_ncnn/F_linear.cpp +# pass_ncnn/F_local_response_norm.cpp +# pass_ncnn/F_log_softmax.cpp +# pass_ncnn/F_logsigmoid.cpp +# pass_ncnn/F_max_pool1d.cpp +# pass_ncnn/F_max_pool2d.cpp +# pass_ncnn/F_max_pool3d.cpp +# pass_ncnn/F_mish.cpp +# pass_ncnn/F_normalize.cpp +# pass_ncnn/F_pad.cpp +# pass_ncnn/F_pixel_shuffle.cpp +# pass_ncnn/F_pixel_unshuffle.cpp +# pass_ncnn/F_prelu.cpp +# pass_ncnn/F_relu.cpp +# pass_ncnn/F_relu6.cpp +# pass_ncnn/F_selu.cpp +# pass_ncnn/F_sigmoid.cpp +# pass_ncnn/F_silu.cpp +# pass_ncnn/F_softmax.cpp +# pass_ncnn/F_tanh.cpp +# pass_ncnn/F_unfold.cpp +# pass_ncnn/F_upsample_bilinear.cpp +# pass_ncnn/F_upsample_nearest.cpp +# pass_ncnn/F_upsample.cpp +# pass_ncnn/nn_AdaptiveAvgPool1d.cpp +# pass_ncnn/nn_AdaptiveAvgPool2d.cpp +# pass_ncnn/nn_AdaptiveAvgPool3d.cpp +# pass_ncnn/nn_AdaptiveMaxPool1d.cpp +# pass_ncnn/nn_AdaptiveMaxPool2d.cpp +# pass_ncnn/nn_AdaptiveMaxPool3d.cpp +# pass_ncnn/nn_AvgPool1d.cpp +# pass_ncnn/nn_AvgPool2d.cpp +# pass_ncnn/nn_AvgPool3d.cpp +# pass_ncnn/nn_BatchNorm1d.cpp +# pass_ncnn/nn_BatchNorm2d.cpp +# pass_ncnn/nn_BatchNorm3d.cpp +# pass_ncnn/nn_CELU.cpp +# pass_ncnn/nn_ChannelShuffle.cpp +# pass_ncnn/nn_ConstantPad1d.cpp +# pass_ncnn/nn_ConstantPad2d.cpp +# pass_ncnn/nn_ConstantPad3d.cpp +# pass_ncnn/nn_Conv1d.cpp +# pass_ncnn/nn_Conv2d.cpp +# pass_ncnn/nn_Conv3d.cpp +# pass_ncnn/nn_ConvTranspose1d.cpp +# pass_ncnn/nn_ConvTranspose2d.cpp +# pass_ncnn/nn_ConvTranspose3d.cpp +# pass_ncnn/nn_ELU.cpp +# pass_ncnn/nn_Embedding.cpp +# pass_ncnn/nn_Fold.cpp +# pass_ncnn/nn_GELU.cpp +# pass_ncnn/nn_GLU.cpp +# pass_ncnn/nn_GroupNorm.cpp +# pass_ncnn/nn_GRU.cpp +# pass_ncnn/nn_Hardsigmoid.cpp +# pass_ncnn/nn_Hardswish.cpp +# pass_ncnn/nn_Hardtanh.cpp +# 
pass_ncnn/nn_InstanceNorm2d.cpp +# pass_ncnn/nn_LayerNorm.cpp +# pass_ncnn/nn_LeakyReLU.cpp +# pass_ncnn/nn_Linear.cpp +# pass_ncnn/nn_LocalResponseNorm.cpp +# pass_ncnn/nn_LogSigmoid.cpp +# pass_ncnn/nn_LogSoftmax.cpp +# pass_ncnn/nn_LSTM.cpp +# pass_ncnn/nn_MaxPool1d.cpp +# pass_ncnn/nn_MaxPool2d.cpp +# pass_ncnn/nn_MaxPool3d.cpp +# pass_ncnn/nn_Mish.cpp +# pass_ncnn/nn_MultiheadAttention.cpp +# pass_ncnn/nn_PixelShuffle.cpp +# pass_ncnn/nn_PixelUnshuffle.cpp +# pass_ncnn/nn_PReLU.cpp +# pass_ncnn/nn_ReflectionPad1d.cpp +# pass_ncnn/nn_ReflectionPad2d.cpp +# pass_ncnn/nn_ReLU.cpp +# pass_ncnn/nn_ReLU6.cpp +# pass_ncnn/nn_ReplicationPad1d.cpp +# pass_ncnn/nn_ReplicationPad2d.cpp +# pass_ncnn/nn_ReplicationPad3d.cpp +# pass_ncnn/nn_RNN.cpp +# pass_ncnn/nn_SELU.cpp +# pass_ncnn/nn_Sigmoid.cpp +# pass_ncnn/nn_SiLU.cpp +# pass_ncnn/nn_Softmax.cpp +# pass_ncnn/nn_Softmax2d.cpp +# pass_ncnn/nn_Tanh.cpp +# pass_ncnn/nn_Unfold.cpp +# pass_ncnn/nn_Upsample.cpp +# pass_ncnn/nn_UpsamplingBilinear2d.cpp +# pass_ncnn/nn_UpsamplingNearest2d.cpp +# pass_ncnn/nn_ZeroPad2d.cpp +# pass_ncnn/Tensor_contiguous.cpp +# pass_ncnn/Tensor_reshape.cpp +# pass_ncnn/Tensor_repeat.cpp +# pass_ncnn/Tensor_view.cpp +# pass_ncnn/torch_addmm.cpp +# pass_ncnn/torch_amax.cpp +# pass_ncnn/torch_amin.cpp +# pass_ncnn/torch_bmm.cpp +# pass_ncnn/torch_clamp.cpp +# pass_ncnn/torch_clone.cpp +# pass_ncnn/torch_cumsum.cpp +# pass_ncnn/torch_diag.cpp +# pass_ncnn/torch_flatten.cpp +# pass_ncnn/torch_logsumexp.cpp +# pass_ncnn/torch_matmul.cpp +# pass_ncnn/torch_max.cpp +# pass_ncnn/torch_mean.cpp +# pass_ncnn/torch_min.cpp +# pass_ncnn/torch_mm.cpp +# pass_ncnn/torch_norm.cpp +# pass_ncnn/torch_permute.cpp +# pass_ncnn/torch_prod.cpp +# pass_ncnn/torch_slice_scatter.cpp +# pass_ncnn/torch_squeeze.cpp +# pass_ncnn/torch_sum.cpp +# pass_ncnn/torch_t.cpp +# pass_ncnn/torch_transpose.cpp +# pass_ncnn/torch_unsqueeze.cpp +# pass_ncnn/torchvision_DeformConv2d.cpp +# ) if(PROTOBUF_FOUND) if(DEFINED PROTOBUF_VERSION AND PROTOBUF_VERSION VERSION_GREATER_EQUAL 3.22) @@ -679,6 +680,7 @@ if(NOT MSVC) endif() set(pnnx_SRCS + config.cpp main.cpp ir.cpp storezip.cpp @@ -696,12 +698,12 @@ set(pnnx_SRCS ${pnnx_pass_level5_SRCS} ${pnnx_pass_level6_SRCS} - pass_ncnn.cpp - save_ncnn.cpp - ${pnnx_pass_ncnn_SRCS} + # pass_ncnn.cpp + # save_ncnn.cpp + # ${pnnx_pass_ncnn_SRCS} ) - +# add_executable(pnnx ${pnnx_SRCS}) file(GLOB_RECURSE SRC_PARSE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parse/*.cpp ) diff --git a/tools/pnnx/src/config.cpp b/tools/pnnx/src/config.cpp new file mode 100644 index 00000000000..3ccfe7e2439 --- /dev/null +++ b/tools/pnnx/src/config.cpp @@ -0,0 +1,5 @@ +// config.cpp +#include "config.h" + +// the flag of dynamic network +bool dynamic_network = false; \ No newline at end of file diff --git a/tools/pnnx/src/config.h b/tools/pnnx/src/config.h new file mode 100644 index 00000000000..58fc945d3a0 --- /dev/null +++ b/tools/pnnx/src/config.h @@ -0,0 +1,8 @@ +// config.h +#ifndef CONFIG_H +#define CONFIG_H + +// the flag of dynamic network +extern bool dynamic_network; + +#endif // CONFIG_H \ No newline at end of file diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index c6c3f5e3d95..b5bb7d74cb5 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -27,7 +27,7 @@ #include #include "storezip.h" #include "utils.h" - +#include namespace pnnx { static bool type_is_integer(int type) @@ -2785,7 +2785,8 @@ std::vector getDirectoryPath(const std::string& filePath) int Graph::python_infer(const std::string& pypath, const 
std::string& binpath,
                        const std::vector<std::string>& customop_modules, std::set<std::string>& custom_ops,
-                        std::string& customop_infer_py)
+                        std::string& customop_infer_py,
+                        std::string& save_dir)
 {
     FILE* pyfp = fopen(pypath.c_str(), "wb");
     if (!pyfp)
@@ -2801,6 +2802,7 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
     fprintf(pyfp, "import torch\n");
     fprintf(pyfp, "import torch.nn as nn\n");
     fprintf(pyfp, "import torch.nn.functional as F\n");
+    fprintf(pyfp, "import importlib.util\n");
     fprintf(pyfp, "try:\n");
     fprintf(pyfp, "    import torchvision\n");
     fprintf(pyfp, "except:\n");
@@ -2867,6 +2869,15 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
     }
     fprintf(pyfp, "\n");
 
+    // emit a load_module() helper that imports a generated sub-model script from an explicit file path
+    {
+        fprintf(pyfp, "def load_module(module_path):\n");
+        fprintf(pyfp, "    spec = importlib.util.spec_from_file_location('module', module_path)\n");
+        fprintf(pyfp, "    module = importlib.util.module_from_spec(spec)\n");
+        fprintf(pyfp, "    spec.loader.exec_module(module)\n");
+        fprintf(pyfp, "    return module\n");
+        fprintf(pyfp, "\n");
+    }
     //add by senli[pnnx_infer]
     fprintf(pyfp, "class Model(nn.Module):\n");
     fprintf(pyfp, "    def __init__(self, bin_path, infer_flag = False):\n");
@@ -2880,6 +2891,18 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
         fprintf(pyfp, "        self.infer_flag = infer_flag\n");
     for (const Operator* op : ops)
     {
+        if (op->type == "pnnx.Loop")
+        {
+            // a pnnx.Loop op delegates to a sub-model saved in save_dir; load it and keep an instance as a member
+            std::string op_name = op->name;
+
+            std::string subModelBinPath = save_dir + "/" + op_name + ".pnnx.bin";
+            std::string subModelInferPath = save_dir + "/" + op_name + "_pnnx_infer.py";
+            fprintf(pyfp, "        %s = load_module('%s')\n", (op_name + "_Mod").c_str(), subModelInferPath.c_str());
+            fprintf(pyfp, "        %s = getattr(%s, 'Model')\n", (op_name + "_Cls").c_str(), (op_name + "_Mod").c_str());
+            fprintf(pyfp, "        %s = %s('%s', True)\n", ("self." + op_name + "_Obj").c_str(), (op_name + "_Cls").c_str(), subModelBinPath.c_str());
+            fprintf(pyfp, "        %s.eval()\n", ("self." + op_name + "_Obj").c_str());
+            continue;
+        }
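+        // e.g. for a loop op named "pnnx_loop_0" (hypothetical), the branch
+        // above adds these lines to the generated __init__:
+        //   pnnx_loop_0_Mod = load_module('<save_dir>/pnnx_loop_0_pnnx_infer.py')
+        //   pnnx_loop_0_Cls = getattr(pnnx_loop_0_Mod, 'Model')
+        //   self.pnnx_loop_0_Obj = pnnx_loop_0_Cls('<save_dir>/pnnx_loop_0.pnnx.bin', True)
+        //   self.pnnx_loop_0_Obj.eval()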
         if (op->type.substr(0, 3) != "nn."
             && op->type.substr(0, 16) != "torchvision.ops.")
             continue;
@@ -3234,6 +3257,7 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
     {
         for (const Operator* op : ops)
         {
+
             if (op->type == "pnnx.Input" || op->type == "pnnx.Output")
                 continue;
@@ -3242,6 +3266,46 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
 
             fprintf(pyfp, "        ");
 
+            if (op->type == "pnnx.Loop")
+            {
+                std::string condition_expr = op->params.at("condition").s;
+                int iter_num = op->params.at("iter_num").i;
+                std::string op_name = op->name;
+                std::vector<Operand*> inputs = op->inputs;
+                std::vector<Operand*> outputs = op->outputs;
+                std::string output_list = "";
+                std::string input_list = "";
+                std::string real_input_list = "";
+                // the first outputs.size() inputs are the loop-carried values;
+                // any further inputs are captured from the enclosing graph
+                for (int index = 0; index < (int)outputs.size(); index++)
+                {
+                    std::string cur_output_name = sanitize_identifier(op->outputs[index]->name);
+                    std::string cur_input_name = sanitize_identifier(op->inputs[index]->name);
+                    output_list = output_list + "v_" + cur_output_name;
+                    input_list = input_list + "v_" + cur_input_name;
+                    if (index + 1 != (int)outputs.size())
+                    {
+                        output_list = output_list + ", ";
+                        input_list = input_list + ", ";
+                    }
+
+                }
+                for (int index = 0; index < (int)inputs.size(); index++)
+                {
+                    std::string cur_input_name = sanitize_identifier(op->inputs[index]->name);
+                    real_input_list = real_input_list + "v_" + cur_input_name;
+                    if (index + 1 != (int)inputs.size())
+                        real_input_list = real_input_list + ", ";
+                }
+                fprintf(pyfp, "%s = %s\n", output_list.c_str(), input_list.c_str());
+                fprintf(pyfp, "        condition = %s\n", condition_expr.c_str());
+                fprintf(pyfp, "        i = 0\n");
+                fprintf(pyfp, "        while condition and i < %s:\n", std::to_string(iter_num).c_str());
+                fprintf(pyfp, "            %s = %s\n", input_list.c_str(), output_list.c_str());
+                fprintf(pyfp, "            %s = %s(%s)\n", output_list.c_str(), ("self." + op_name + "_Obj").c_str(), real_input_list.c_str());
+                fprintf(pyfp, "            i += 1\n");
+                continue;
+            }
+
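+            // e.g. for a loop carrying two values plus one captured input,
+            // the branch above emits this forward() fragment (hypothetical names):
+            //   v_y0, v_y1 = v_x0, v_x1
+            //   condition = <condition expr>
+            //   i = 0
+            //   while condition and i < <iter_num>:
+            //       v_x0, v_x1 = v_y0, v_y1
+            //       v_y0, v_y1 = self.pnnx_loop_0_Obj(v_x0, v_x1, v_extra)
+            //       i += 1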
+ op_name + "_Obj").c_str(), real_input_list.c_str()); + fprintf(pyfp, " i += 1\n"); + continue; + } + if (op->type == "pnnx.Expression") { // expr @@ -4417,6 +4481,16 @@ Operand* Graph::get_operand(const std::string& name) return 0; } +Operator* Graph::get_operator(const std::string& name) +{ + for (Operator* r : ops) + { + if (r->name == name) + return r; + } + + return 0; +} const Operand* Graph::get_operand(const std::string& name) const { for (const Operand* r : operands) diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 148bea2cd6f..597227cc6b9 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -325,7 +325,7 @@ class Graph //add by senli[pnnx_infer] int python_infer(const std::string& pypath, const std::string& binpath, const std::vector& customop_modules, std::set& custom_ops, - std::string& customop_infer_py); + std::string& customop_infer_py, std::string& save_dir); int parse(const std::string& param); @@ -346,6 +346,9 @@ class Graph Operand* new_operand(const std::string& name); Operand* get_operand(const std::string& name); + + Operator* get_operator(const std::string& name); + const Operand* get_operand(const std::string& name) const; int extract_sub_graph(const std::vector& start_nodes, const std::vector& end_nodes); diff --git a/tools/pnnx/src/load_torchscript.cpp b/tools/pnnx/src/load_torchscript.cpp index 12cc4129fb2..ac8fc837182 100644 --- a/tools/pnnx/src/load_torchscript.cpp +++ b/tools/pnnx/src/load_torchscript.cpp @@ -429,7 +429,8 @@ const torch::jit::Node* find_node_by_kind(const std::shared_ptr>& pnnx_graph_map, const std::string& device, const std::vector >& input_shapes, const std::vector& input_types, @@ -544,7 +545,7 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph, fprintf(stderr, "############# pass_level1\n"); - pnnx::pass_level1(mod, g, module_operators, pnnx_graph); + pnnx::pass_level1(mod, g, module_operators, pnnx_graph_map); return 0; } diff --git a/tools/pnnx/src/load_torchscript.h b/tools/pnnx/src/load_torchscript.h index 31a8a421723..27af7043c39 100644 --- a/tools/pnnx/src/load_torchscript.h +++ b/tools/pnnx/src/load_torchscript.h @@ -14,12 +14,13 @@ #ifndef PNNX_LOAD_TORCHSCRIPT_H #define PNNX_LOAD_TORCHSCRIPT_H - +#include #include "ir.h" namespace pnnx { -int load_torchscript(const std::string& ptpath, Graph& g, +int load_torchscript(const std::string& ptpath, \ + std::unordered_map>& pnnx_graph_map, const std::string& device, const std::vector >& input_shapes, const std::vector& input_types, diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 4a8f1865f61..4ea17e355f7 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -18,7 +18,9 @@ #include #include #include +#include +#include "config.h" #include "ir.h" #include "pass_level2.h" #include "pass_level3.h" @@ -33,7 +35,7 @@ #endif #include "pass_ncnn.h" -#include "save_ncnn.h" +// #include "save_ncnn.h" #if BUILD_PNNX2ONNX #include "save_onnx.h" @@ -206,18 +208,18 @@ int main(int argc, char** argv) } std::string ptpath = std::string(argv[1]); - + std::string save_dir = std::string(argv[2]); std::string ptbase = get_basename(ptpath); - std::string pnnxparampath = ptbase + ".pnnx.param"; - std::string pnnxbinpath = ptbase + ".pnnx.bin"; - std::string pnnxpypath = ptbase + "_pnnx.py"; + // std::string pnnxparampath = ptbase + ".pnnx.param"; + // std::string pnnxbinpath = ptbase + ".pnnx.bin"; + // std::string pnnxpypath = ptbase + "_pnnx.py"; // add by senli[pnnx_infer] - std::string pnnxinferpath = ptbase + "_pnnx_infer.py"; - 
-    std::string pnnxonnxpath = ptbase + ".pnnx.onnx";
-    std::string ncnnparampath = ptbase + ".ncnn.param";
-    std::string ncnnbinpath = ptbase + ".ncnn.bin";
-    std::string ncnnpypath = ptbase + "_ncnn.py";
+    // std::string pnnxinferpath = ptbase + "_pnnx_infer.py";
+    // std::string pnnxonnxpath = ptbase + ".pnnx.onnx";
+    // std::string ncnnparampath = ptbase + ".ncnn.param";
+    // std::string ncnnbinpath = ptbase + ".ncnn.bin";
+    // std::string ncnnpypath = ptbase + "_ncnn.py";
     int fp16 = 1;
     int optlevel = 2;
     std::string device = "cpu";
@@ -231,8 +233,8 @@ int main(int argc, char** argv)
     std::string customop_infer_py = "None";
     std::vector<std::string> start_nodes;
    std::vector<std::string> end_nodes;
-
-    for (int i = 2; i < argc; i++)
+    std::string extract_model_name = "model";
+    for (int i = 3; i < argc; i++)
     {
         // key=value
         char* kv = argv[i];
@@ -249,24 +251,24 @@ int main(int argc, char** argv)
         const char* key = kv;
         char* value = eqs + 1;
 
-        if (strcmp(key, "pnnxparam") == 0)
-            pnnxparampath = std::string(value);
-        if (strcmp(key, "pnnxbin") == 0)
-            pnnxbinpath = std::string(value);
-        if (strcmp(key, "pnnxpy") == 0)
-            pnnxpypath = std::string(value);
-        // add by senli[pnnx_infer]
-        if (strcmp(key, "pnnxinferpy") == 0)
-            pnnxinferpath = std::string(value);
-
-        if (strcmp(key, "pnnxonnx") == 0)
-            pnnxonnxpath = std::string(value);
-        if (strcmp(key, "ncnnparam") == 0)
-            ncnnparampath = std::string(value);
-        if (strcmp(key, "ncnnbin") == 0)
-            ncnnbinpath = std::string(value);
-        if (strcmp(key, "ncnnpy") == 0)
-            ncnnpypath = std::string(value);
+        // if (strcmp(key, "pnnxparam") == 0)
+        //     pnnxparampath = std::string(value);
+        // if (strcmp(key, "pnnxbin") == 0)
+        //     pnnxbinpath = std::string(value);
+        // if (strcmp(key, "pnnxpy") == 0)
+        //     pnnxpypath = std::string(value);
+        // // add by senli[pnnx_infer]
+        // if (strcmp(key, "pnnxinferpy") == 0)
+        //     pnnxinferpath = std::string(value);
+
+        // if (strcmp(key, "pnnxonnx") == 0)
+        //     pnnxonnxpath = std::string(value);
+        // if (strcmp(key, "ncnnparam") == 0)
+        //     ncnnparampath = std::string(value);
+        // if (strcmp(key, "ncnnbin") == 0)
+        //     ncnnbinpath = std::string(value);
+        // if (strcmp(key, "ncnnpy") == 0)
+        //     ncnnpypath = std::string(value);
         if (strcmp(key, "fp16") == 0)
             fp16 = atoi(value);
         if (strcmp(key, "optlevel") == 0)
@@ -288,20 +290,25 @@ int main(int argc, char** argv)
             parse_string_list(value, start_nodes);
         if (strcmp(key, "end_nodes") == 0)
             parse_string_list(value, end_nodes);
+        if (strcmp(key, "extract_model_name") == 0)
+            extract_model_name = value;
+
     }
 
     // print options
     {
-        fprintf(stderr, "pnnxparam = %s\n", pnnxparampath.c_str());
-        fprintf(stderr, "pnnxbin = %s\n", pnnxbinpath.c_str());
-        fprintf(stderr, "pnnxpy = %s\n", pnnxpypath.c_str());
-        // add by senli[pnnx_infer]
-        fprintf(stderr, "pnnxinferpy = %s\n", pnnxinferpath.c_str());
-
-        fprintf(stderr, "pnnxonnx = %s\n", pnnxonnxpath.c_str());
-        fprintf(stderr, "ncnnparam = %s\n", ncnnparampath.c_str());
-        fprintf(stderr, "ncnnbin = %s\n", ncnnbinpath.c_str());
-        fprintf(stderr, "ncnnpy = %s\n", ncnnpypath.c_str());
+        // fprintf(stderr, "pnnxparam = %s\n", pnnxparampath.c_str());
+        // fprintf(stderr, "pnnxbin = %s\n", pnnxbinpath.c_str());
+        // fprintf(stderr, "pnnxpy = %s\n", pnnxpypath.c_str());
+        // // add by senli[pnnx_infer]
+        // fprintf(stderr, "pnnxinferpy = %s\n", pnnxinferpath.c_str());
+
+        // fprintf(stderr, "pnnxonnx = %s\n", pnnxonnxpath.c_str());
+        // fprintf(stderr, "ncnnparam = %s\n", ncnnparampath.c_str());
+        // fprintf(stderr, "ncnnbin = %s\n", ncnnbinpath.c_str());
+        // fprintf(stderr, "ncnnpy = %s\n", ncnnpypath.c_str());
"ncnnpy = %s\n", ncnnpypath.c_str()); + + fprintf(stderr, "save_dir = %s\n", save_dir.c_str()); fprintf(stderr, "fp16 = %d\n", fp16); fprintf(stderr, "optlevel = %d\n", optlevel); fprintf(stderr, "device = %s\n", device.c_str()); @@ -326,13 +333,20 @@ int main(int argc, char** argv) fprintf(stderr, "end_nodes = "); print_string_list(end_nodes); fprintf(stderr, "\n"); + fprintf(stderr, "extract_model_name = %s\n", extract_model_name.c_str()); + fprintf(stderr, "\n"); + } std::set foldable_constants; std::string foldable_constants_zippath = ptbase + ".foldable_constants.zip"; - pnnx::Graph pnnx_graph; - load_torchscript(ptpath, pnnx_graph, + if(input_shapes2.size() > 0) + { + dynamic_network = true; + } + std::unordered_map> pnnx_graph_map; + load_torchscript(ptpath, pnnx_graph_map, device, input_shapes, input_types, input_shapes2, input_types2, customop_modules, module_operators, @@ -341,76 +355,99 @@ int main(int argc, char** argv) // load_onnx(ptpath.c_str(), pnnx_graph); // g->dump(); + // #ifdef NDEBUG + // loop all graph tp pass + for (const auto& graph_pair : pnnx_graph_map) { + + std::string graph_name = graph_pair.first; + if(graph_name == "src") + graph_name = "model"; + std::string pnnxparampath = save_dir + "/" + graph_name + ".pnnx.param"; + std::string pnnxbinpath = save_dir + "/" + graph_name + ".pnnx.bin"; + std::string pnnxpypath = save_dir + "/" + graph_name + "_pnnx.py"; + std::string pnnxinferpath = save_dir + "/" + graph_name + "_pnnx_infer.py"; - fprintf(stderr, "############# pass_level2\n"); - - pnnx::pass_level2(pnnx_graph); - - pnnx_graph.save("debug.param", "debug.bin"); - // add by senli - std::set custom_ops; + fprintf(stderr, "pnnxparam = %s\n", pnnxparampath.c_str()); + fprintf(stderr, "pnnxbin = %s\n", pnnxbinpath.c_str()); + fprintf(stderr, "pnnxpy = %s\n", pnnxpypath.c_str()); + fprintf(stderr, "pnnxinferpy = %s\n", pnnxinferpath.c_str()); + + fprintf(stderr, "############# pass_level2 at %s\n", graph_name.c_str()); + pnnx::pass_level2(graph_pair.second); + + // pnnx_graph.save("debug.param", "debug.bin"); + // add by senli + std::set custom_ops; - if (optlevel >= 1) - { - fprintf(stderr, "############# pass_level3\n"); + if (optlevel >= 1) + { + fprintf(stderr, "############# pass_level3 at %s\n", graph_name.c_str()); - pnnx::pass_level3(pnnx_graph, foldable_constants, foldable_constants_zippath); + pnnx::pass_level3(graph_pair.second, foldable_constants, foldable_constants_zippath); - fprintf(stderr, "############# pass_level4\n"); + fprintf(stderr, "############# pass_level4 at %s\n", graph_name.c_str()); - // add by senli - pnnx::pass_level4(pnnx_graph, custom_ops); - } + // add by senli + pnnx::pass_level4(graph_pair.second, custom_ops); + } - pnnx_graph.save("debug2.param", "debug2.bin"); + // pnnx_graph.save("debug2.param", "debug2.bin"); - if (optlevel >= 2) - { - fprintf(stderr, "############# pass_level5\n"); + if (optlevel >= 2) + { + fprintf(stderr, "############# pass_level5 at %s\n", graph_name.c_str()); - pnnx::pass_level5(pnnx_graph, foldable_constants, foldable_constants_zippath); + pnnx::pass_level5(graph_pair.second, foldable_constants, foldable_constants_zippath); - // add by senli 20240321 - fprintf(stderr, "############# pass_level6\n"); + // add by senli 20240321 + fprintf(stderr, "############# pass_level6 at %s\n", graph_name.c_str()); - pnnx::pass_level6(pnnx_graph, foldable_constants, foldable_constants_zippath); - } + pnnx::pass_level6(graph_pair.second, foldable_constants, foldable_constants_zippath); + } - // delete 
-    // delete foldable_constants_zippath
-    remove(foldable_constants_zippath.c_str());
-
-    // extract_sub_graph
-    int extract_flag = pnnx_graph.extract_sub_graph(start_nodes, end_nodes);
-    if(extract_flag == -1)
-    {
-        fprintf(stderr, "############# failed to extract_sub_graph\n");
-    }
+
-    pnnx_graph.save(pnnxparampath, pnnxbinpath);
+        // extract_sub_graph
+        if (extract_model_name == graph_name)
+        {
+            fprintf(stderr, "############# start to extract_sub_graph in %s\n", graph_name.c_str());
+            int extract_flag = graph_pair.second->extract_sub_graph(start_nodes, end_nodes);
+            if (extract_flag == -1)
+            {
+                fprintf(stderr, "############# failed to extract_sub_graph\n");
+            }
+        }
+
+        graph_pair.second->save(pnnxparampath, pnnxbinpath);
 
-    pnnx_graph.python(pnnxpypath, pnnxbinpath);
-    //add by senli[pnnx_infer]
-    pnnx_graph.python_infer(pnnxinferpath, pnnxbinpath, customop_modules, custom_ops, customop_infer_py);
+        graph_pair.second->python(pnnxpypath, pnnxbinpath);
+        //add by senli[pnnx_infer]
 
-#if BUILD_PNNX2ONNX
-    pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16);
-#else
-    fprintf(stderr, "pnnx build without onnx-zero support, skip saving onnx\n");
-#endif
+        graph_pair.second->python_infer(pnnxinferpath, pnnxbinpath, customop_modules, custom_ops, customop_infer_py, save_dir);
 
-    // if (optlevel >= 2)
-    // {
-    //     fprintf(stderr, "############# pass_ncnn\n");
+        #if BUILD_PNNX2ONNX
+        pnnx::save_onnx(*graph_pair.second, pnnxonnxpath.c_str(), fp16);
+        #else
+        fprintf(stderr, "pnnx build without onnx-zero support, skip saving onnx\n");
+        #endif
+        // #endif
+        // if (optlevel >= 2)
+        // {
+        //     fprintf(stderr, "############# pass_ncnn\n");
 
-    //     pnnx::pass_ncnn(pnnx_graph, module_operators);
+        //     pnnx::pass_ncnn(pnnx_graph, module_operators);
 
-    //     pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath, fp16);
-    // }
+        //     pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath, fp16);
+        // }
 
-    // pnnx::Graph pnnx_graph2;
+        // pnnx::Graph pnnx_graph2;
 
-    // pnnx_graph2.load("pnnx.param", "pnnx.bin");
-    // pnnx_graph2.save("pnnx2.param", "pnnx2.bin");
+        // pnnx_graph2.load("pnnx.param", "pnnx.bin");
+        // pnnx_graph2.save("pnnx2.param", "pnnx2.bin");
+    }
+    // delete foldable_constants_zippath
+    remove(foldable_constants_zippath.c_str());
+
     return 0;
 }
diff --git a/tools/pnnx/src/parse/pnnx_graph_parse.cpp b/tools/pnnx/src/parse/pnnx_graph_parse.cpp
index b9f64e76f19..e1857f5c609 100644
--- a/tools/pnnx/src/parse/pnnx_graph_parse.cpp
+++ b/tools/pnnx/src/parse/pnnx_graph_parse.cpp
@@ -2,23 +2,23 @@ int main(int argc, char** argv);
 
 namespace pnnx_graph {
-bool PnnxGraph::getNvpPnnxModel(const std::string& pt_path, const std::string& input_shape, const std::string& custom_op_path,
-                                const std::string& custom_op_py, const std::string& start_nodes, const std::string& end_nodes)
+bool PnnxGraph::getNvpPnnxModel(const std::string& pt_path, const std::string& save_dir, const std::string& input_shape, const std::string& custom_op_path,
+                                const std::string& custom_op_py, const std::string& start_nodes, const std::string& end_nodes, const std::string& extract_model_name)
 {
     int argc;
     char** argv;
     if (custom_op_path != "None" && custom_op_py != "None")
     {
-        argc = 7;
+        argc = 9;
     }
     else if (custom_op_path != "None" && custom_op_py == "None")
     {
-        argc = 6;
+        argc = 8;
     }
     else if (custom_op_path == "None" && custom_op_py == "None")
     {
-        argc = 5;
+        argc = 7;
     }
 
     argv = new char*[argc];
@@ -30,36 +30,44 @@ bool PnnxGraph::getNvpPnnxModel(const std::string& pt_path, const std::string& input_shape, const std::string& custom_op_path,
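+    // argv layout: argv[0] = program, argv[1] = pt_path, argv[2] = save_dir,
+    // then inputshape= and the optional customop= / customop_infer_py= pairs,
+    // with start_nodes=, end_nodes= and extract_model_name= in the last three slots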
     argv[1] = new char[pt_path.size() + 1];
     std::strcpy(argv[1], pt_path.c_str());
 
+    argv[2] = new char[save_dir.size() + 1];
+    std::strcpy(argv[2], save_dir.c_str());
+
     //insert input_shape
     std::string input_shape_info = "inputshape=" + input_shape;
-    argv[2] = new char[input_shape_info.size() + 1];
-    std::strcpy(argv[2], input_shape_info.c_str());
+    argv[3] = new char[input_shape_info.size() + 1];
+    std::strcpy(argv[3], input_shape_info.c_str());
 
     if (custom_op_path != "None")
     {
         //insert custom_op
         std::string custom_op_info = "customop=" + custom_op_path;
-        argv[3] = new char[custom_op_info.size() + 1];
-        std::strcpy(argv[3], custom_op_info.c_str());
+        argv[4] = new char[custom_op_info.size() + 1];
+        std::strcpy(argv[4], custom_op_info.c_str());
     }
     if (custom_op_py != "None")
     {
         //insert custom_op_py
         std::string custom_op_py_info = "customop_infer_py=" + custom_op_py;
-        argv[4] = new char[custom_op_py_info.size() + 1];
-        std::strcpy(argv[4], custom_op_py_info.c_str());
+        argv[5] = new char[custom_op_py_info.size() + 1];
+        std::strcpy(argv[5], custom_op_py_info.c_str());
     }
 
     //insert start nodes
     std::string stard_nodes_info = "start_nodes=" + start_nodes;
-    argv[argc - 2] = new char[stard_nodes_info.size() + 1];
-    std::strcpy( argv[argc - 2], stard_nodes_info.c_str());
+    argv[argc - 3] = new char[stard_nodes_info.size() + 1];
+    std::strcpy( argv[argc - 3], stard_nodes_info.c_str());
 
     //insert end nodes
     std::string end_nodes_info = "end_nodes=" + end_nodes;
-    argv[argc - 1] = new char[end_nodes_info.size() + 1];
-    std::strcpy( argv[argc - 1], end_nodes_info.c_str());
+    argv[argc - 2] = new char[end_nodes_info.size() + 1];
+    std::strcpy( argv[argc - 2], end_nodes_info.c_str());
+
+    //insert extract_model_name
+    std::string extract_model_name_info = "extract_model_name=" + extract_model_name;
+    argv[argc - 1] = new char[extract_model_name_info.size() + 1];
+    std::strcpy(argv[argc - 1], extract_model_name_info.c_str());
 
     int result = main(argc, argv);
@@ -99,7 +107,6 @@ bool PnnxGraph::loadModel(const std::string& param_path, const std::string& bin_path)
         return false;
     }
 
-    std::cout << "123" << bin_path << std::endl;
     //parse all operator
     std::vector<Operator*> operators_;
     std::vector<Operand*> operands_;
diff --git a/tools/pnnx/src/parse/pnnx_graph_parse.h b/tools/pnnx/src/parse/pnnx_graph_parse.h
index 55845467bb5..c05704a0398 100644
--- a/tools/pnnx/src/parse/pnnx_graph_parse.h
+++ b/tools/pnnx/src/parse/pnnx_graph_parse.h
@@ -24,11 +24,14 @@ class PnnxGraph
      * @return false
      */
     bool getNvpPnnxModel(const std::string& pt_path, \
+                         const std::string& save_dir, \
                          const std::string& input_shape, \
                          const std::string& custom_op_path, \
-                         const std::string& custom_op_py, const std::string& start_nodes = "", const std::string& end_nodes = "");
+                         const std::string& custom_op_py, const std::string& start_nodes = "", const std::string& end_nodes = "",
+                         const std::string& extract_model_name = "model");
+
+
 
     /**
      * @brief load pnnx graph
diff --git a/tools/pnnx/src/parse/pnnx_ir_parse.cpp b/tools/pnnx/src/parse/pnnx_ir_parse.cpp
index 7c0bfe8855a..86228b1a5d0 100644
--- a/tools/pnnx/src/parse/pnnx_ir_parse.cpp
+++ b/tools/pnnx/src/parse/pnnx_ir_parse.cpp
@@ -42,7 +42,6 @@ static size_t countSubstring(const std::string& str, const std::string& substr)
     return count;
 }
-
 static bool type_is_integer(int type)
 {
     if (type == 1) return false;
@@ -477,7 +476,7 @@ Attribute operator+(const Attribute& a, const Attribute& b)
 
 Parameter Parameter::parse_from_string(const std::string& value)
 {
-    if (value.find('%') != std::string::npos)
+    if (value.find('%') != std::string::npos)
     {
         Parameter p;
         p.type = 4;
@@ -566,6 +565,7 @@ Parameter Parameter::parse_from_string(const std::string& value)
     p.i = std::stoi(value);
     return p;
 }
+
 Graph::Graph()
 {
 }
@@ -1024,7 +1024,6 @@ int Graph::save(const std::string& parampath, const std::string& binpath)
     return 0;
 }
-
 int Graph::save_param(const std::string& parampath, const std::vector<Operator*>& input_operators, const std::vector<Operand*>& input_operands)
 {
     FILE* paramfp = fopen(parampath.c_str(), "wb");
diff --git a/tools/pnnx/src/parse/pnnx_ir_parse.h b/tools/pnnx/src/parse/pnnx_ir_parse.h
index 9f2a4c721dc..2ac4b09a93b 100644
--- a/tools/pnnx/src/parse/pnnx_ir_parse.h
+++ b/tools/pnnx/src/parse/pnnx_ir_parse.h
@@ -211,7 +211,6 @@ class Graph
     int load(const std::string& parampath, const std::string& binpath);
     int save(const std::string& parampath, const std::string& binpath);
-
     int save_param(const std::string& parampath, const std::vector<Operator*>& input_operators, const std::vector<Operand*>& input_operands);
 
     int python(const std::string& pypath, const std::string& binpath);
diff --git a/tools/pnnx/src/pass_level1.cpp b/tools/pnnx/src/pass_level1.cpp
index aa61fac007c..439022cdd86 100644
--- a/tools/pnnx/src/pass_level1.cpp
+++ b/tools/pnnx/src/pass_level1.cpp
@@ -11,7 +11,6 @@
 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations under the License.
-
 #include
 #include
@@ -50,15 +49,15 @@ FuseModulePassRegister::~FuseModulePassRegister()
     delete pass;
 }
 
-static void fuse_moduleop_unpack(Graph& graph, const std::vector<std::string>& module_operators)
+static void fuse_moduleop_unpack(std::shared_ptr<Graph>& graph, const std::vector<std::string>& module_operators)
 {
     while (1)
     {
         bool matched = false;
 
-        for (size_t i = 0; i < graph.ops.size(); i++)
+        for (size_t i = 0; i < graph->ops.size(); i++)
         {
-            Operator* op = graph.ops[i];
+            Operator* op = graph->ops[i];
 
             if (std::find(module_operators.begin(), module_operators.end(), op->type) == module_operators.end())
                 continue;
@@ -88,7 +87,7 @@ static void fuse_moduleop_unpack(Graph& graph, const std::vector<std::string>& module_operators)
             op2->inputs.clear();
             op2->outputs.clear();
 
-            graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2));
+            graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2));
 
             delete op2;
 
@@ -100,8 +99,659 @@ static void fuse_moduleop_unpack(Graph& graph, const std::vector<std::string>& module_operators)
     }
 }
 
+
+void pass_level1_block(const torch::jit::Module& mod, Operator* src_op, torch::jit::Block* sub_block, std::shared_ptr<Graph>& sub_pnnx_graph, \
+                       const std::vector<std::string>& module_operators, std::unordered_map<std::string, std::shared_ptr<Graph> >& pnnx_graph_map, \
+                       int& pnnx_unknown_index, int& pnnx_loop_index)
+{
+    // create_input
+    int last_input_op_index = 0;
+
+    for (int i = 1; i < (int)sub_block->inputs().size(); i++)
+    {
+        char input_name[32];
+        sprintf(input_name, "pnnx_input_%d", i - 1);
+        last_input_op_index = i - 1;
+        const auto& block_input = sub_block->inputs()[i];
+        // block_input->debugName()
+        Operator* op = sub_pnnx_graph->new_operator("pnnx.Input", input_name);
+        Operand* r = sub_pnnx_graph->new_operand(block_input);
+        Operand* src_r = src_op->inputs.at(i + 1);
+        r->params = src_r->params;
+        r->type = src_r->type;
+        r->shape = src_r->shape;
+        r->producer = op;
+        op->outputs.push_back(r);
+    }
+    std::map<std::string, std::string> class_type_to_names;
+    for (const auto& n : sub_block->nodes())
+    {
+        if (n->kind() == c10::prim::GetAttr)
+        {
+            // pass
+            std::string name = n->s(torch::jit::attr::name);
+            // std::string name = n->debugName();
+
+            auto class_type = n->output(0)->type()->cast<torch::jit::ClassType>();
+
+            if (class_type)
+            {
+                std::string
class_type_str = class_type->str(); + class_type_to_names[class_type_str] = name; + // class_type_to_names[class_type_str] = class_type_str + "." + name; + } + else + { + // Tensor from some class + // Operator* op = pg->new_operator(n->kind().toDisplayString(), name); + Operator* op = sub_pnnx_graph->new_operator("pnnx.Attribute", name); + + for (int i = 0; i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = sub_pnnx_graph->new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + + std::deque module_names; // = split(n->input(0)->node()->s(torch::jit::attr::name), '.'); + { + auto np = n->input(0)->node(); + while (np->hasAttribute(torch::jit::attr::name)) + { + module_names.push_front(np->s(torch::jit::attr::name)); + np = np->input(0)->node(); + } + } + + std::string wrapped_name; + auto sub_mod = mod; + for (auto module_name : module_names) + { + if (wrapped_name.size() > 0) + wrapped_name = wrapped_name + "." + module_name; + else + wrapped_name = module_name; + sub_mod = sub_mod.attr(module_name).toModule(); + } + + if (wrapped_name.empty()) + { + // top-level module + wrapped_name = name; + } + + op->name = wrapped_name; + + // op->params["this"] = n->input(i) + + // sub_mod.dump(true, true, true); + + op->attrs["data"] = sub_mod.attr(name).toTensor(); + op->outputs[0]->type = op->attrs["data"].type; + op->outputs[0]->shape = op->attrs["data"].shape; + } + } + else if (n->kind() == c10::prim::Constant) // || n->kind() == c10::prim::ListConstruct) + { + char name[32]; + sprintf(name, "pnnx_%d", pnnx_unknown_index++); + + Operator* op = sub_pnnx_graph->new_operator(n->kind().toDisplayString(), name); + + for (int i = 0; i < (int)n->inputs().size(); i++) + { + const auto& in = n->input(i); + Operand* r = sub_pnnx_graph->get_operand(in->debugName()); + if (r == 0) + { + Operand* r1 = 0; + for (const auto& graph_pair : pnnx_graph_map) { + r1 = graph_pair.second->get_operand(in->debugName()); + if(r1 != 0) + break; + } + + assert(r1 != 0 && "cur tensor name : %s not in graph\n",in->debugName().c_str()); + + // if(r1 == 0) + // { + // throw std::exception("cur tensor name : %s not in graph\n",in->debugName().c_str()); + // } + std::string last_input_op_name = "pnnx_input_" + std::to_string(last_input_op_index); + Operator* last_input_op = sub_pnnx_graph->get_operator(last_input_op_name); + assert(last_input_op != 0 && "failed to find last input op : %s\n",last_input_op_name.c_str()); + if(r1->producer->type == "prim::Constant") + { + Operator* constant_op = r1->producer; + // insert type of prim::Constant new input to sub_graph + Operator* new_input_op = sub_pnnx_graph->new_operator_after(constant_op->type, constant_op->name, last_input_op); + new_input_op->inputnames = constant_op->inputnames; + new_input_op->params = constant_op->params; + new_input_op->attrs = constant_op->attrs; + Operand* r2 = sub_pnnx_graph->new_operand(in->debugName()); + r2->producer = new_input_op; + r2->consumers.push_back(op); + r2->params = r1->params; + r2->type = r1->type; + r2->shape = r1->shape; + new_input_op->outputs.push_back(r2); + op->inputs.push_back(r2); + } + else + { + // insert new input to loop + r1->consumers.push_back(src_op); + src_op->inputs.push_back(r1); + // insert new input to sub_graph + last_input_op_index++; + std::string new_input_op_name = "pnnx_input_" + std::to_string(last_input_op_index); + Operator* new_input_op = sub_pnnx_graph->new_operator_after("pnnx.Input", new_input_op_name, last_input_op); + Operand* r2 = 
sub_pnnx_graph->new_operand(in->debugName()); + r2->producer = new_input_op; + r2->consumers.push_back(op); + r2->params = r1->params; + r2->type = r1->type; + r2->shape = r1->shape; + new_input_op->outputs.push_back(r2); + op->inputs.push_back(r2); + } + + } + else{ + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + + } + + for (int i = 0; i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = sub_pnnx_graph->new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + + op->params["value"] = n; + + if (op->params["value"].type == 8) + { + op->type = "pnnx.Attribute"; + + op->params.erase("value"); + + op->attrs["data"] = n->t(torch::jit::attr::value); + } + } + else if (n->kind() == c10::prim::CallMethod) + { + auto class_type = n->input(0)->type()->cast(); + // const std::string& name = n->s(torch::jit::attr::name); + + // fprintf(stderr, "call %s\n", class_type->str().c_str()); + + std::string name = class_type_to_names[class_type->str()]; + + std::string class_type_str = torch::jit::removeTorchMangle(class_type->str()); + + std::string class_type_str_no_torch_prefix = class_type_str.substr(10); + + std::string optypename = class_type_str; + + for (const auto& ow : get_global_pnnx_fuse_module_passes()) + { + if (class_type_str != ow->match_type_str()) + continue; + + optypename = ow->type_str(); + break; + } + + if (optypename == class_type_str) + { + optypename = class_type_str_no_torch_prefix; + } + + Operator* op = sub_pnnx_graph->new_operator(optypename, name); + + for (int i = 1; i < (int)n->inputs().size(); i++) + { + const auto& in = n->input(i); + Operand* r = sub_pnnx_graph->get_operand(in->debugName()); + if (r == 0) + { + Operand* r1 = 0; + for (const auto& graph_pair : pnnx_graph_map) { + r1 = graph_pair.second->get_operand(in->debugName()); + if(r1 != 0) + break; + } + + assert(r1 != 0 && "cur tensor name : %s not in graph\n",in->debugName().c_str()); + // if(r1 == 0) + // { + // throw std::exception("cur tensor name : %s not in graph\n",in->debugName().c_str()); + // } + + std::string last_input_op_name = "pnnx_input_" + std::to_string(last_input_op_index); + Operator* last_input_op = sub_pnnx_graph->get_operator(last_input_op_name); + assert(last_input_op != 0 && "failed to find last input op : %s\n",last_input_op_name.c_str()); + if(r1->producer->type == "prim::Constant") + { + Operator* constant_op = r1->producer; + // insert type of prim::Constant new input to sub_graph + Operator* new_input_op = sub_pnnx_graph->new_operator_after(constant_op->type, constant_op->name, last_input_op); + new_input_op->inputnames = constant_op->inputnames; + new_input_op->params = constant_op->params; + new_input_op->attrs = constant_op->attrs; + Operand* r2 = sub_pnnx_graph->new_operand(in->debugName()); + r2->producer = new_input_op; + r2->consumers.push_back(op); + r2->params = r1->params; + r2->type = r1->type; + r2->shape = r1->shape; + new_input_op->outputs.push_back(r2); + op->inputs.push_back(r2); + } + else + { + // insert new input to loop + r1->consumers.push_back(src_op); + src_op->inputs.push_back(r1); + // insert new input to sub_graph + last_input_op_index++; + std::string new_input_op_name = "pnnx_input_" + std::to_string(last_input_op_index); + Operator* new_input_op = sub_pnnx_graph->new_operator_after("pnnx.Input", new_input_op_name, last_input_op); + Operand* r2 = sub_pnnx_graph->new_operand(in->debugName()); + r2->producer = new_input_op; + r2->consumers.push_back(op); + r2->params = r1->params; + r2->type = 
r1->type; + r2->shape = r1->shape; + new_input_op->outputs.push_back(r2); + op->inputs.push_back(r2); + } + + } + else{ + r->consumers.push_back(op); + op->inputs.push_back(r); + } + } + + for (int i = 0; i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = sub_pnnx_graph->new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + + // module operator + if (std::find(module_operators.begin(), module_operators.end(), class_type_str_no_torch_prefix) != module_operators.end()) + { + const std::string& function_name = n->s(torch::jit::attr::name); + torch::jit::Function& function = class_type->getMethod(function_name); + if (function.isGraphFunction()) + { +#if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 11) + torch::jit::Block* moduleop_block = toGraphFunction(function).graph()->block(); +#else + torch::jit::Block* moduleop_block = function.graph()->block(); +#endif + + std::map constant_attr_nodes; + for (const auto& mn : moduleop_block->nodes()) + { + if (mn->kind() == c10::prim::GetAttr) + { + std::string name = mn->s(torch::jit::attr::name); + // std::string name = mn->debugName(); + + auto class_type = mn->output(0)->type()->cast(); + + if (!class_type) + { + std::deque module_names; // = split(mn->input(0)->node()->s(torch::jit::attr::name), '.'); + { + auto np = n->input(0)->node(); + while (np->hasAttribute(torch::jit::attr::name)) + { + module_names.push_front(np->s(torch::jit::attr::name)); + np = np->input(0)->node(); + } + } + std::deque module_names2; + { + auto np = mn->input(0)->node(); + while (np->hasAttribute(torch::jit::attr::name)) + { + module_names2.push_front(np->s(torch::jit::attr::name)); + np = np->input(0)->node(); + } + } + for (auto x : module_names2) + { + module_names.push_back(x); + } + + auto sub_mod = mod; + for (auto module_name : module_names) + { + sub_mod = sub_mod.attr(module_name).toModule(); + } + + std::string wrapped_name; + for (auto module_name : module_names2) + { + if (wrapped_name.size() > 0) + wrapped_name = wrapped_name + "." + module_name; + else + wrapped_name = module_name; + } + + if (wrapped_name.empty()) + { + // top-level module + wrapped_name = name; + } + else + { + wrapped_name = wrapped_name + "." + name; + } + + op->attrs[wrapped_name] = sub_mod.attr(name).toTensor(); + } + } + else if (mn->kind() == c10::prim::Constant) + { + Parameter p(mn); + + if (p.type == 8) + { + size_t unique_id = mn->output(0)->unique(); + constant_attr_nodes[unique_id] = mn; + } + } + } + + int pnnx_moduleop_unknown_index = 0; + for (auto attr : constant_attr_nodes) + { + char name[32]; + sprintf(name, "pnnx_%02d", pnnx_moduleop_unknown_index); + op->attrs[name] = attr.second->t(torch::jit::attr::value); + pnnx_moduleop_unknown_index++; + } + } + } + else + { + for (const auto& ow : get_global_pnnx_fuse_module_passes()) + { + if (class_type_str != ow->match_type_str()) + continue; + + auto class_type = n->input(0)->type()->cast(); + torch::jit::Function& function = class_type->getMethod(n->s(torch::jit::attr::name)); + + std::deque module_names; // = split(n->input(0)->node()->s(torch::jit::attr::name), '.'); + { + auto np = n->input(0)->node(); + while (np->hasAttribute(torch::jit::attr::name)) + { + module_names.push_front(np->s(torch::jit::attr::name)); + np = np->input(0)->node(); + } + } + + std::string wrapped_name; + auto sub_mod = mod; + for (auto module_name : module_names) + { + if (wrapped_name.size() > 0) + wrapped_name = wrapped_name + "." 
+ module_name; + else + wrapped_name = module_name; + sub_mod = sub_mod.attr(module_name).toModule(); + } + + op->name = wrapped_name; + +#if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 11) + ow->write(op, toGraphFunction(function).graph(), sub_mod); +#else + ow->write(op, function.graph(), sub_mod); +#endif + + break; + } + } + } + // else if (n->kind() == c10::prim::CallFunction) + // { + // fprintf(stderr, "function %s", n->kind().toDisplayString()); + // + // AT_ASSERT(cur->input(0)->node()->kind() == c10::prim::Constant); + // auto function_constant = cur->input(0)->node(); + // auto fun_type = function_constant->output()->type()->expect(); + // if (!fun_type->function()->isGraphFunction()) + // { + // continue; + // } + // cur->removeInput(0); + // + // fprintf(stderr, "inline function %s\n", fun_type->function()->name().c_str()); + // + // GRAPH_UPDATE("Inlining function '", fun_type->function()->name(), "' to ", *cur); + // GRAPH_UPDATE("Function body: ", *fun_type->function()->optimized_graph()); + // inlineCallTo(cur, fun_type->function(), false); + // break; + // } + else if(n->kind() == c10::prim::Loop) + { + char loop_op_name[32]; + sprintf(loop_op_name, "pnnx_loop_%d", pnnx_loop_index++); + + Operator* op = sub_pnnx_graph->new_operator(n->kind().toDisplayString(), loop_op_name); + + for (int i = 0; i < (int)n->inputs().size(); i++) + { + const auto& in = n->input(i); + Operand* r =sub_pnnx_graph->get_operand(in->debugName()); + if (r == 0) + { + Operand* r1 = 0; + for (const auto& graph_pair : pnnx_graph_map) { + r1 = graph_pair.second->get_operand(in->debugName()); + if(r1 != 0) + break; + } + + assert(r1 != 0 && "cur tensor name : %s not in graph\n",in->debugName().c_str()); + // if(r1 == 0) + // { + // throw std::exception("cur tensor name : %s not in graph\n",in->debugName().c_str()); + // } + + std::string last_input_op_name = "pnnx_input_" + std::to_string(last_input_op_index); + Operator* last_input_op = sub_pnnx_graph->get_operator(last_input_op_name); + assert(last_input_op != 0 && "failed to find last input op : %s\n",last_input_op_name.c_str()); + if(r1->producer->type == "prim::Constant") + { + Operator* constant_op = r1->producer; + // insert type of prim::Constant new input to sub_graph + Operator* new_input_op = sub_pnnx_graph->new_operator_after(constant_op->type, constant_op->name, last_input_op); + new_input_op->inputnames = constant_op->inputnames; + new_input_op->params = constant_op->params; + new_input_op->attrs = constant_op->attrs; + Operand* r2 = sub_pnnx_graph->new_operand(in->debugName()); + r2->producer = new_input_op; + r2->consumers.push_back(op); + r2->params = r1->params; + r2->type = r1->type; + r2->shape = r1->shape; + new_input_op->outputs.push_back(r2); + op->inputs.push_back(r2); + } + else + { + // insert new input to loop + r1->consumers.push_back(src_op); + src_op->inputs.push_back(r1); + // insert new input to sub_graph + last_input_op_index++; + std::string new_input_op_name = "pnnx_input_" + std::to_string(last_input_op_index); + Operator* new_input_op = sub_pnnx_graph->new_operator_after("pnnx.Input", new_input_op_name, last_input_op); + Operand* r2 = sub_pnnx_graph->new_operand(in->debugName()); + r2->producer = new_input_op; + r2->consumers.push_back(op); + r2->params = r1->params; + r2->type = r1->type; + r2->shape = r1->shape; + new_input_op->outputs.push_back(r2); + op->inputs.push_back(r2); + } + + } + else{ + r->consumers.push_back(op); + op->inputs.push_back(r); + } + } + + for (int i = 0; 
i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = sub_pnnx_graph->new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + std::shared_ptr sub_pnnx_graph2 = std::make_shared(); + int block_num = 0; + for (torch::jit::Block* subBlock2 : n->blocks()) + { + + assert(block_num == 0 && "block num > 1 in loop"); + pass_level1_block(mod, op, subBlock2, sub_pnnx_graph2, module_operators, pnnx_graph_map, pnnx_unknown_index, pnnx_loop_index); + block_num++; + } + pnnx_graph_map[std::string(loop_op_name)] = sub_pnnx_graph2; + + } + else + { + char name[32]; + sprintf(name, "pnnx_%d", pnnx_unknown_index++); + + Operator* op = sub_pnnx_graph->new_operator(n->kind().toDisplayString(), name); + + for (int i = 0; i < (int)n->inputs().size(); i++) + { + const auto& in = n->input(i); + Operand* r = sub_pnnx_graph->get_operand(in->debugName()); + if (r == 0) + { + Operand* r1 = 0; + for (const auto& graph_pair : pnnx_graph_map) { + r1 = graph_pair.second->get_operand(in->debugName()); + if(r1 != 0) + break; + } + + assert(r1 != 0 && "cur tensor name : %s not in graph\n",in->debugName().c_str()); + // if(r1 == 0) + // { + // throw std::exception("cur tensor name : %s not in graph\n",in->debugName().c_str()); + // } + + std::string last_input_op_name = "pnnx_input_" + std::to_string(last_input_op_index); + Operator* last_input_op = sub_pnnx_graph->get_operator(last_input_op_name); + assert(last_input_op != 0 && "failed to find last input op : %s\n",last_input_op_name.c_str()); + if(r1->producer->type == "prim::Constant") + { + Operator* constant_op = r1->producer; + // insert type of prim::Constant new input to sub_graph + Operator* new_input_op = sub_pnnx_graph->new_operator_after(constant_op->type, constant_op->name, last_input_op); + new_input_op->inputnames = constant_op->inputnames; + new_input_op->params = constant_op->params; + new_input_op->attrs = constant_op->attrs; + Operand* r2 = sub_pnnx_graph->new_operand(in->debugName()); + r2->producer = new_input_op; + r2->consumers.push_back(op); + r2->params = r1->params; + r2->type = r1->type; + r2->shape = r1->shape; + new_input_op->outputs.push_back(r2); + op->inputs.push_back(r2); + } + else + { + // insert new input to loop + r1->consumers.push_back(src_op); + src_op->inputs.push_back(r1); + // insert new input to sub_graph + last_input_op_index++; + std::string new_input_op_name = "pnnx_input_" + std::to_string(last_input_op_index); + Operator* new_input_op = sub_pnnx_graph->new_operator_after("pnnx.Input", new_input_op_name, last_input_op); + Operand* r2 = sub_pnnx_graph->new_operand(in->debugName()); + r2->producer = new_input_op; + r2->consumers.push_back(op); + r2->params = r1->params; + r2->type = r1->type; + r2->shape = r1->shape; + new_input_op->outputs.push_back(r2); + op->inputs.push_back(r2); + } + + } + else{ + r->consumers.push_back(op); + op->inputs.push_back(r); + } + } + + for (int i = 0; i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = sub_pnnx_graph->new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + + } + } + + for (int i = 1; i < (int)sub_block->outputs().size(); i++) + { + const auto& out = sub_block->outputs()[i]; + char output_name[32]; + sprintf(output_name, "pnnx_output_%d", i - 1); + // out->debugName() + Operator* op = sub_pnnx_graph->new_operator("pnnx.Output", output_name); + Operand* r = sub_pnnx_graph->get_operand(out->debugName()); + Operand* src_r = src_op->outputs.at(i - 1); + r->params = src_r->params; + r->type 
+        r->params = src_r->params;
+        r->type = src_r->type;
+        r->shape = src_r->shape;
+        r->consumers.push_back(op);
+        op->inputs.push_back(r);
+    }
+
+    // post process
+    fuse_moduleop_unpack(sub_pnnx_graph, module_operators);
+}
+
-void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit::Graph>& g, const std::vector<std::string>& module_operators, Graph& pg)
+void pass_level1(const torch::jit::Module& mod,
+                 const std::shared_ptr<torch::jit::Graph>& g,
+                 const std::vector<std::string>& module_operators,
+                 std::unordered_map<std::string, std::shared_ptr<Graph> >& pnnx_graph_map)
 {
+    std::shared_ptr<Graph> pg = std::make_shared<Graph>();
+    pnnx_graph_map["src"] = pg;
     for (int i = 1; i < (int)g->inputs().size(); i++)
     {
         const auto& in = g->inputs()[i];
@@ -109,15 +759,15 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
 
-        Operator* op = pg.new_operator("pnnx.Input", name);
-        Operand* r = pg.new_operand(in);
+        Operator* op = pg->new_operator("pnnx.Input", name);
+        Operand* r = pg->new_operand(in);
         r->producer = op;
         op->outputs.push_back(r);
     }
 
     std::map<std::string, std::string> class_type_to_names;
     int pnnx_unknown_index = 0;
-
+    int pnnx_loop_index = 0;
     for (const auto& n : g->block()->nodes())
     {
         if (n->kind() == c10::prim::GetAttr)
@@ -137,13 +787,13 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
 
-            // Operator* op = pg.new_operator(n->kind().toDisplayString(), name);
-            Operator* op = pg.new_operator("pnnx.Attribute", name);
+            // Operator* op = pg->new_operator(n->kind().toDisplayString(), name);
+            Operator* op = pg->new_operator("pnnx.Attribute", name);
 
             for (int i = 0; i < (int)n->outputs().size(); i++)
             {
                 const auto& on = n->output(i);
-                Operand* r = pg.new_operand(on);
+                Operand* r = pg->new_operand(on);
                 r->producer = op;
                 op->outputs.push_back(r);
             }
@@ -191,12 +841,12 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
 
-            Operator* op = pg.new_operator(n->kind().toDisplayString(), name);
+            Operator* op = pg->new_operator(n->kind().toDisplayString(), name);
 
             for (int i = 0; i < (int)n->inputs().size(); i++)
             {
                 const auto& in = n->input(i);
-                Operand* r = pg.get_operand(in->debugName());
+                Operand* r = pg->get_operand(in->debugName());
                 r->consumers.push_back(op);
                 op->inputs.push_back(r);
             }
@@ -204,7 +854,7 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
             for (int i = 0; i < (int)n->outputs().size(); i++)
             {
                 const auto& on = n->output(i);
-                Operand* r = pg.new_operand(on);
+                Operand* r = pg->new_operand(on);
                 r->producer = op;
                 op->outputs.push_back(r);
             }
@@ -249,12 +899,12 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
 
-            Operator* op = pg.new_operator(optypename, name);
+            Operator* op = pg->new_operator(optypename, name);
 
             for (int i = 1; i < (int)n->inputs().size(); i++)
             {
                 const auto& in = n->input(i);
-                Operand* r = pg.get_operand(in->debugName());
+                Operand* r = pg->get_operand(in->debugName());
                 r->consumers.push_back(op);
                 op->inputs.push_back(r);
             }
@@ -262,7 +912,7 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
             for (int i = 0; i < (int)n->outputs().size(); i++)
             {
                 const auto& on = n->output(i);
-                Operand* r = pg.new_operand(on);
+                Operand* r = pg->new_operand(on);
                 r->producer = op;
                 op->outputs.push_back(r);
             }
@@ -428,17 +1078,52 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
         // ...function(), false);
         //     break;
         // }
+        else if (n->kind() == c10::prim::Loop)
+        {
+            char loop_op_name[32];
+            sprintf(loop_op_name, "pnnx_loop_%d", pnnx_loop_index++);
+
+            Operator* op = pg->new_operator(n->kind().toDisplayString(), loop_op_name);
+
+            for (int i = 0; i < (int)n->inputs().size(); i++)
+            {
+                const auto& in = n->input(i);
+                Operand* r = pg->get_operand(in->debugName());
+                r->consumers.push_back(op);
+                op->inputs.push_back(r);
+            }
+
+            for (int i = 0; i < (int)n->outputs().size(); i++)
+            {
+                const auto& on = n->output(i);
+                Operand* r = pg->new_operand(on);
+                r->producer = op;
+                op->outputs.push_back(r);
+            }
+
+            std::shared_ptr<Graph> sub_pnnx_graph = std::make_shared<Graph>();
+            int block_num = 0;
+            for (torch::jit::Block* subBlock : n->blocks())
+            {
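+                // a prim::Loop node is expected to carry exactly one body block;
+                // the assert below guards that assumption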
+                assert(block_num == 0 && "block num > 1 in loop");
+                pass_level1_block(mod, op, subBlock, sub_pnnx_graph, module_operators, pnnx_graph_map, pnnx_unknown_index, pnnx_loop_index);
+                block_num++;
+            }
+            pnnx_graph_map[std::string(loop_op_name)] = sub_pnnx_graph;
+
+        }
         else
         {
             char name[32];
             sprintf(name, "pnnx_%d", pnnx_unknown_index++);
 
-            Operator* op = pg.new_operator(n->kind().toDisplayString(), name);
+            Operator* op = pg->new_operator(n->kind().toDisplayString(), name);
 
             for (int i = 0; i < (int)n->inputs().size(); i++)
             {
                 const auto& in = n->input(i);
-                Operand* r = pg.get_operand(in->debugName());
+                Operand* r = pg->get_operand(in->debugName());
                 r->consumers.push_back(op);
                 op->inputs.push_back(r);
             }
@@ -446,10 +1131,11 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
             for (int i = 0; i < (int)n->outputs().size(); i++)
             {
                 const auto& on = n->output(i);
-                Operand* r = pg.new_operand(on);
+                Operand* r = pg->new_operand(on);
                 r->producer = op;
                 op->outputs.push_back(r);
             }
+
         }
     }
@@ -459,8 +1145,8 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch
 
-        Operator* op = pg.new_operator("pnnx.Output", name);
-        Operand* r = pg.get_operand(in->debugName());
+        Operator* op = pg->new_operator("pnnx.Output", name);
+        Operand* r = pg->get_operand(in->debugName());
         r->consumers.push_back(op);
         op->inputs.push_back(r);
     }
diff --git a/tools/pnnx/src/pass_level1.h b/tools/pnnx/src/pass_level1.h
index 1eb5a7ab9af..8f54dee359f 100644
--- a/tools/pnnx/src/pass_level1.h
+++ b/tools/pnnx/src/pass_level1.h
@@ -14,7 +14,8 @@
 #ifndef PNNX_PASS_LEVEL1_H
 #define PNNX_PASS_LEVEL1_H
-
+#include <unordered_map>
+#include <memory>
 #include <torch/script.h>
 #include "ir.h"
@@ -48,7 +49,9 @@ const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes();
 #define REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CLASS) \
     static FuseModulePassRegister g_global_pnnx_fusemodulepass_##CLASS##_register(new CLASS);
 
-void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit::Graph>& g, const std::vector<std::string>& module_operators, Graph& pg);
+void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit::Graph>& g,
+                 const std::vector<std::string>& module_operators,
+                 std::unordered_map<std::string, std::shared_ptr<Graph> >& pnnx_graph_map);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level1/nn_Conv2d.cpp b/tools/pnnx/src/pass_level1/nn_Conv2d.cpp
index 4907aab668e..47dd478e016 100644
--- a/tools/pnnx/src/pass_level1/nn_Conv2d.cpp
+++ b/tools/pnnx/src/pass_level1/nn_Conv2d.cpp
@@ -55,7 +55,6 @@ class Conv2d : public FuseModulePass
         {
             convolution = convolution_mode;
         }
-
         if(!convolution)
         {
             return;
@@ -72,8 +71,6 @@ class Conv2d : public FuseModulePass
             fprintf(stderr, "Caught an unknown exception\n");
         }
-
-
         op->params["groups"] = convolution->namedInput("groups");
         op->params["in_channels"] = weight.size(1) * op->params["groups"].i;
         op->params["out_channels"] = weight.size(0);
         op->params["kernel_size"] = Parameter{weight.size(2), weight.size(3)};
diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp
index d8feb795812..57469b5c895 100644
--- a/tools/pnnx/src/pass_level2.cpp
+++ b/tools/pnnx/src/pass_level2.cpp
@@ -780,7 +780,7 @@ static bool match(const Operator* anchor, const Operator* pattern, std::map<std
 
-void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opindex)
+void pnnx_graph_rewrite(std::shared_ptr<Graph> graph, const GraphRewriterPass* pass, int& opindex)
 {
     Graph pattern_graph;
     pattern_graph.parse(pass->match_pattern_graph());
@@ -791,6 +791,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
     std::vector<Operator*> pattern_graph_output_operators;
     for (const auto& x : pattern_graph.ops)
     {
+        // printf("op_name = %s, op_type = %s, pass_name = %s\n", x->name.c_str(), x->type.c_str(), pass->type_str());
         if (x->type == "pnnx.Input")
         {
             for (const auto& y : x->outputs)
@@ -808,7 +809,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
 
     while (1)
     {
-        const int graph_op_count = (int)graph.ops.size();
+        const int graph_op_count = (int)graph->ops.size();
 
         bool matched = true;
 
@@ -834,7 +835,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
             int j = q;
             for (; j >= 0; j--)
             {
-                const Operator* anchor = graph.ops[j];
+                const Operator* anchor = graph->ops[j];
 
                 std::map<std::string, const Operator*> matched_operators2;
                 std::map<std::string, Operand*> matched_inputs2;
@@ -966,7 +967,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
         for (auto& _x : operands_to_remove)
         {
             Operand* r = _x.second;
-            graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), r));
+            graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), r));
             delete r;
         }
 
@@ -976,12 +977,12 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
             int cur_index = 1;
             for (auto& o : matched_operators)
             {
-                int c_index = std::find(graph.ops.begin(), graph.ops.end(), o.second) - graph.ops.begin();
+                int c_index = std::find(graph->ops.begin(), graph->ops.end(), o.second) - graph->ops.begin();
                 cur_index = std::max(cur_index, c_index + 1);
             }
 
-            cur_index = std::min(cur_index, (int)graph.ops.size() - 1);
-            cur = graph.ops[cur_index];
+            cur_index = std::min(cur_index, (int)graph->ops.size() - 1);
+            cur = graph->ops[cur_index];
         }
 
         // remove all matched_operators
@@ -991,7 +992,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
 
             Operator* x = (Operator*)_x.second;
 
-            graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), x));
+            graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), x));
 
             delete _x.second;
         }
 
@@ -999,7 +1000,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
         if (pass->replace_pattern_graph() == 0)
         {
             // insert single
-            Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur);
+            Operator* op = graph->new_operator_before(pass->type_str(), std::string(pass->name_str()), cur);
 
             for (const auto& k : pattern_graph_inputs)
             {
@@ -1035,7 +1036,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
             if (op->type == "pnnx.Input" || op->type == "pnnx.Output")
                 continue;
 
-            graph.ops.insert(std::find(graph.ops.begin(), graph.ops.end(), cur), op);
+            graph->ops.insert(std::find(graph->ops.begin(), graph->ops.end(), cur), op);
             replace_graph.ops[i] = 0;
             ops[op->name] = op;
         }
@@ -1046,7 +1047,7 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
             if (r->producer->type == "pnnx.Input" || (r->consumers.size() == 1 && r->consumers[0]->type == "pnnx.Output"))
                 continue;
 
-            graph.operands.push_back(r);
+            graph->operands.push_back(r);
             replace_graph.operands[i] = 0;
         }
 
@@ -1121,9 +1122,9 @@ static bool is_alias_op(const Operator* op)
     return false;
 }
 
-static void functionize(Graph& graph)
+static void functionize(std::shared_ptr<Graph> graph)
 {
-    // graph.save("0.param", "0.bin");
+    // graph->save("0.param", "0.bin");
 
     // 1. create shadow view/slice/select/... for each consumer
     // 2. replace inplace op, append copy
@@ -1137,9 +1138,9 @@ static void functionize(Graph& graph)
 
     // 1. create shadow view/slice/select/...
for each consumer { - for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + for (int i = (int)graph->ops.size() - 1; i >= 0; i--) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (!is_alias_op(op)) continue; @@ -1153,9 +1154,9 @@ static void functionize(Graph& graph) { Operator* op1 = out0->consumers[j]; - Operator* op_shadow = graph.new_operator_after(op->type, op->name + "_pnnxshadow_" + std::to_string(j), op); + Operator* op_shadow = graph->new_operator_after(op->type, op->name + "_pnnxshadow_" + std::to_string(j), op); - Operand* shadow_out = graph.new_operand(op_shadow->name + "_out"); + Operand* shadow_out = graph->new_operand(op_shadow->name + "_out"); op_shadow->inputs = op->inputs; op_shadow->params = op->params; @@ -1184,14 +1185,14 @@ static void functionize(Graph& graph) } } - // graph.save("1.param", "1.bin"); + // graph->save("1.param", "1.bin"); // 2. replace inplace op, append copy // 3. tag operand alias for view/slice/select/... output { - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; bool is_inplace_op = op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_'; @@ -1207,17 +1208,17 @@ static void functionize(Graph& graph) } else { - alias_index = std::find(graph.operands.begin(), graph.operands.end(), in) - graph.operands.begin(); + alias_index = std::find(graph->operands.begin(), graph->operands.end(), in) - graph->operands.begin(); } if (op->type == "aten::copy_") { op->outputs[0]->params["__alias__"] = alias_index; - // fprintf(stderr, "operand %s is alias of %s\n", op->outputs[0]->name.c_str(), graph.operands[alias_index]->name.c_str()); + // fprintf(stderr, "operand %s is alias of %s\n", op->outputs[0]->name.c_str(), graph->operands[alias_index]->name.c_str()); // set copy output shape as the alias one - op->outputs[0]->type = graph.operands[alias_index]->type; - op->outputs[0]->shape = graph.operands[alias_index]->shape; + op->outputs[0]->type = graph->operands[alias_index]->type; + op->outputs[0]->shape = graph->operands[alias_index]->shape; continue; } @@ -1225,7 +1226,7 @@ static void functionize(Graph& graph) if (is_alias_op(op)) { op->outputs[0]->params["__alias__"] = alias_index; - // fprintf(stderr, "operand %s is alias of %s\n", op->outputs[0]->name.c_str(), graph.operands[alias_index]->name.c_str()); + // fprintf(stderr, "operand %s is alias of %s\n", op->outputs[0]->name.c_str(), graph->operands[alias_index]->name.c_str()); continue; } @@ -1235,13 +1236,13 @@ static void functionize(Graph& graph) op->type = op->type.substr(0, op->type.size() - 1); // append aten::copy_ - if (graph.operands[alias_index]->consumers.size() > 1) + if (graph->operands[alias_index]->consumers.size() > 1) { Operand* in0 = op->inputs[0]; Operand* out0 = op->outputs[0]; - Operator* op_copy = graph.new_operator_after("aten::copy_", op->name + "_copy", op); - Operand* copy_out = graph.new_operand(op->name + "_copy_out"); + Operator* op_copy = graph->new_operator_after("aten::copy_", op->name + "_copy", op); + Operand* copy_out = graph->new_operand(op->name + "_copy_out"); op_copy->inputs.push_back(in0); op_copy->inputs.push_back(out0); @@ -1255,13 +1256,13 @@ static void functionize(Graph& graph) } } - // graph.save("3.param", "3.bin"); + // graph->save("3.param", "3.bin"); // 4. 
scan inplace copy op, collect affacted alias { - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "aten::copy_") continue; @@ -1272,16 +1273,16 @@ static void functionize(Graph& graph) // inplace op output always alias with the input const int alias_index = out0->params.at("__alias__").i; - Operand* alias_in0 = graph.operands[alias_index]; + Operand* alias_in0 = graph->operands[alias_index]; // fprintf(stderr, "\n---> %s for %s\n", op->name.c_str(), alias_in0->name.c_str()); size_t i_advanced = 0; // 5. look for any op after the inplace op with alias input - for (size_t j = i + 1; j < graph.ops.size(); j++) + for (size_t j = i + 1; j < graph->ops.size(); j++) { - Operator* op1 = graph.ops[j]; + Operator* op1 = graph->ops[j]; bool affacted = false; for (Operand* x : op1->inputs) @@ -1309,12 +1310,12 @@ static void functionize(Graph& graph) // 6. collect ops on the chain back to alias std::set chainsx_op_indexes; { - size_t op1_index = std::find(graph.ops.begin(), graph.ops.end(), op1) - graph.ops.begin(); + size_t op1_index = std::find(graph->ops.begin(), graph->ops.end(), op1) - graph->ops.begin(); if (op1_index < i - i_advanced) { chainsx_op_indexes.insert(op1_index); - // fprintf(stderr, "affacted op %s for %s\n", op1->name.c_str(), graph.operands[alias_index]->name.c_str()); + // fprintf(stderr, "affacted op %s for %s\n", op1->name.c_str(), graph->operands[alias_index]->name.c_str()); } while (1) @@ -1328,12 +1329,12 @@ static void functionize(Graph& graph) break; op1 = x->producer; - size_t op1_index = std::find(graph.ops.begin(), graph.ops.end(), op1) - graph.ops.begin(); + size_t op1_index = std::find(graph->ops.begin(), graph->ops.end(), op1) - graph->ops.begin(); if (op1_index < i - i_advanced) { chainsx_op_indexes.insert(op1_index); - // fprintf(stderr, "affacted op %s for %s chained\n", op1->name.c_str(), graph.operands[alias_index]->name.c_str()); + // fprintf(stderr, "affacted op %s for %s chained\n", op1->name.c_str(), graph->operands[alias_index]->name.c_str()); } } } @@ -1344,11 +1345,11 @@ static void functionize(Graph& graph) for (size_t doi : chainsx_op_indexes) { doi -= k; - // fprintf(stderr, "---> move %s after %s\n", graph.ops[doi]->name.c_str(), graph.ops[i - i_advanced]->name.c_str()); + // fprintf(stderr, "---> move %s after %s\n", graph->ops[doi]->name.c_str(), graph->ops[i - i_advanced]->name.c_str()); for (size_t l = doi; l < i - i_advanced; l++) { - std::swap(graph.ops[l], graph.ops[l + 1]); + std::swap(graph->ops[l], graph->ops[l + 1]); } k += 1; @@ -1359,10 +1360,10 @@ static void functionize(Graph& graph) // 8. update all alias uses after copy op, retag alias out0->params.erase("__alias__"); - const int new_alias_index = std::find(graph.operands.begin(), graph.operands.end(), out0) - graph.operands.begin(); - for (size_t k = i - i_advanced + 1; k < graph.ops.size(); k++) + const int new_alias_index = std::find(graph->operands.begin(), graph->operands.end(), out0) - graph->operands.begin(); + for (size_t k = i - i_advanced + 1; k < graph->ops.size(); k++) { - Operator* op2 = graph.ops[k]; + Operator* op2 = graph->ops[k]; // bool use_in0 = false; for (size_t l = 0; l < op2->inputs.size(); l++) @@ -1392,18 +1393,18 @@ static void functionize(Graph& graph) } } - // graph.save("4.param", "4.bin"); + // graph->save("4.param", "4.bin"); // 9. 
clear all alias tag
     {
-        for (Operand* x : graph.operands)
+        for (Operand* x : graph->operands)
         {
             x->params.erase("__alias__");
         }
     }
 }
 
-void pass_level2(Graph& g)
+void pass_level2(std::shared_ptr<Graph> g)
 {
     functionize(g);
 
diff --git a/tools/pnnx/src/pass_level2.h b/tools/pnnx/src/pass_level2.h
index 4c94faff56e..134623bee6f 100644
--- a/tools/pnnx/src/pass_level2.h
+++ b/tools/pnnx/src/pass_level2.h
@@ -58,9 +58,9 @@ class GraphRewriterPassRegister
 #define REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(CLASS, PRIORITY) \
     static GraphRewriterPassRegister g_global_pnnx_graphrewriterpass_##CLASS##_register(new CLASS, PRIORITY);
 
-void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opindex);
+void pnnx_graph_rewrite(std::shared_ptr<Graph> graph, const GraphRewriterPass* pass, int& opindex);
 
-void pass_level2(Graph& g);
+void pass_level2(std::shared_ptr<Graph> g);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level3.cpp b/tools/pnnx/src/pass_level3.cpp
index ed79d94f2d6..974162b74a1 100644
--- a/tools/pnnx/src/pass_level3.cpp
+++ b/tools/pnnx/src/pass_level3.cpp
@@ -37,7 +37,7 @@ namespace pnnx {
 
-void pass_level3(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
+void pass_level3(std::shared_ptr<Graph> g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
 {
     assign_unique_name(g);
 
diff --git a/tools/pnnx/src/pass_level3.h b/tools/pnnx/src/pass_level3.h
index 208d19e1eea..423bcc24bf9 100644
--- a/tools/pnnx/src/pass_level3.h
+++ b/tools/pnnx/src/pass_level3.h
@@ -19,7 +19,7 @@ namespace pnnx {
 
-void pass_level3(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
+void pass_level3(std::shared_ptr<Graph> g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level3/assign_unique_name.cpp b/tools/pnnx/src/pass_level3/assign_unique_name.cpp
index fa387e12b3f..4b34b1fbb67 100644
--- a/tools/pnnx/src/pass_level3/assign_unique_name.cpp
+++ b/tools/pnnx/src/pass_level3/assign_unique_name.cpp
@@ -17,16 +17,16 @@ namespace pnnx {
 
-void assign_unique_name(Graph& graph)
+void assign_unique_name(std::shared_ptr<Graph> graph)
 {
     // assign unique name for all operators
     {
         std::unordered_set<std::string> names;
         int make_unique_index = 0;
 
-        for (size_t i = 0; i < graph.ops.size(); i++)
+        for (size_t i = 0; i < graph->ops.size(); i++)
         {
-            Operator* op = graph.ops[i];
+            Operator* op = graph->ops[i];
 
             const std::string& name = op->name;
 
             if (names.find(name) == names.end())
diff --git a/tools/pnnx/src/pass_level3/assign_unique_name.h b/tools/pnnx/src/pass_level3/assign_unique_name.h
index afdd5b73cab..62edafb853b 100644
--- a/tools/pnnx/src/pass_level3/assign_unique_name.h
+++ b/tools/pnnx/src/pass_level3/assign_unique_name.h
@@ -16,6 +16,6 @@ namespace pnnx {
 
-void assign_unique_name(Graph& graph);
+void assign_unique_name(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level3/eliminate_noop_math.cpp b/tools/pnnx/src/pass_level3/eliminate_noop_math.cpp
index 5822490ad10..9d143c7ace7 100644
--- a/tools/pnnx/src/pass_level3/eliminate_noop_math.cpp
+++ b/tools/pnnx/src/pass_level3/eliminate_noop_math.cpp
@@ -150,16 +150,16 @@ static bool operator_is_all_constant(const Operator* op, float vf, int vi)
     return false;
 }
 
-void eliminate_noop_math(Graph& graph)
+void eliminate_noop_math(std::shared_ptr<Graph> graph)
 {
     for (;;)
     {
         bool need_eliminate = false;
 
         // build expression via reverse order
-        for (int i = (int)graph.ops.size() - 1; i
>= 0; i--) + for (int i = (int)graph->ops.size() - 1; i >= 0; i--) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; int identity_input_id = 0; if (op->type == "aten::add") @@ -342,13 +342,13 @@ void eliminate_noop_math(Graph& graph) math_out->producer = 0; math_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), math_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), math_out)); delete math_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level3/eliminate_noop_math.h b/tools/pnnx/src/pass_level3/eliminate_noop_math.h index 08d0113c3f1..3224e2686f3 100644 --- a/tools/pnnx/src/pass_level3/eliminate_noop_math.h +++ b/tools/pnnx/src/pass_level3/eliminate_noop_math.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_math(Graph& graph); +void eliminate_noop_math(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/eliminate_tuple_pair.cpp b/tools/pnnx/src/pass_level3/eliminate_tuple_pair.cpp index 013538f65ff..a1f22f0a155 100644 --- a/tools/pnnx/src/pass_level3/eliminate_tuple_pair.cpp +++ b/tools/pnnx/src/pass_level3/eliminate_tuple_pair.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_tuple_pair(Graph& graph) +void eliminate_tuple_pair(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "prim::TupleConstruct") continue; @@ -64,11 +64,11 @@ void eliminate_tuple_pair(Graph& graph) op2->outputs[j]->producer = 0; op2->outputs[j]->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op2->outputs[j])); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), op2->outputs[j])); delete op2->outputs[j]; } - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op->outputs[0])); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), op->outputs[0])); delete op->outputs[0]; op->inputs.clear(); @@ -77,11 +77,11 @@ void eliminate_tuple_pair(Graph& graph) op2->inputs.clear(); op2->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op)); delete op; - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2)); delete op2; diff --git a/tools/pnnx/src/pass_level3/eliminate_tuple_pair.h b/tools/pnnx/src/pass_level3/eliminate_tuple_pair.h index df70eda27a6..6fdc547e044 100644 --- a/tools/pnnx/src/pass_level3/eliminate_tuple_pair.h +++ b/tools/pnnx/src/pass_level3/eliminate_tuple_pair.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_tuple_pair(Graph& graph); +void eliminate_tuple_pair(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/expand_quantization_modules.cpp b/tools/pnnx/src/pass_level3/expand_quantization_modules.cpp index d856de94e16..ccdae972dbf 100644 --- a/tools/pnnx/src/pass_level3/expand_quantization_modules.cpp +++ b/tools/pnnx/src/pass_level3/expand_quantization_modules.cpp @@ -18,15 +18,15 @@ namespace pnnx { -void expand_quantization_modules(Graph& graph) +void 
expand_quantization_modules(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type == "nn.intrinsic.quantized.ConvReLU2d") { @@ -48,19 +48,19 @@ void expand_quantization_modules(Graph& graph) // insert new operator before all output consumers const Operator* cur = 0; { - int cur_index = graph.ops.size() - 1; + int cur_index = graph->ops.size() - 1; for (auto& c : op->outputs[0]->consumers) { - int c_index = std::find(graph.ops.begin(), graph.ops.end(), c) - graph.ops.begin(); + int c_index = std::find(graph->ops.begin(), graph->ops.end(), c) - graph->ops.begin(); cur_index = std::min(cur_index, c_index); } - cur = graph.ops[cur_index]; + cur = graph->ops[cur_index]; } - Operator* op_relu = graph.new_operator_before("nn.ReLU", op->name + "_relu", cur); + Operator* op_relu = graph->new_operator_before("nn.ReLU", op->name + "_relu", cur); - Operand* r0 = graph.new_operand(op->name + "_norelu"); + Operand* r0 = graph->new_operand(op->name + "_norelu"); r0->producer = op; r0->consumers.push_back(op_relu); diff --git a/tools/pnnx/src/pass_level3/expand_quantization_modules.h b/tools/pnnx/src/pass_level3/expand_quantization_modules.h index a57cbc8a20f..e5dfcc9d4c3 100644 --- a/tools/pnnx/src/pass_level3/expand_quantization_modules.h +++ b/tools/pnnx/src/pass_level3/expand_quantization_modules.h @@ -16,6 +16,6 @@ namespace pnnx { -void expand_quantization_modules(Graph& graph); +void expand_quantization_modules(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_dynamic_adaptive_pool.cpp b/tools/pnnx/src/pass_level3/fuse_dynamic_adaptive_pool.cpp index de532b18e21..74740854f1f 100644 --- a/tools/pnnx/src/pass_level3/fuse_dynamic_adaptive_pool.cpp +++ b/tools/pnnx/src/pass_level3/fuse_dynamic_adaptive_pool.cpp @@ -676,7 +676,7 @@ pnnx.Output output 2 0 out indices } }; -void fuse_dynamic_adaptive_pool(Graph& graph) +void fuse_dynamic_adaptive_pool(std::shared_ptr graph) { fuse_dynamic_adaptive_pool_pass a; fuse_dynamic_adaptive_pool_pass_1 b; diff --git a/tools/pnnx/src/pass_level3/fuse_dynamic_adaptive_pool.h b/tools/pnnx/src/pass_level3/fuse_dynamic_adaptive_pool.h index ce7a190a18f..a5ff8fe7d2c 100644 --- a/tools/pnnx/src/pass_level3/fuse_dynamic_adaptive_pool.h +++ b/tools/pnnx/src/pass_level3/fuse_dynamic_adaptive_pool.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_dynamic_adaptive_pool(Graph& graph); +void fuse_dynamic_adaptive_pool(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_einsum_operands.cpp b/tools/pnnx/src/pass_level3/fuse_einsum_operands.cpp index 58ef3ec2666..e6b611d854c 100644 --- a/tools/pnnx/src/pass_level3/fuse_einsum_operands.cpp +++ b/tools/pnnx/src/pass_level3/fuse_einsum_operands.cpp @@ -18,15 +18,15 @@ namespace pnnx { -void fuse_einsum_operands(Graph& graph) +void fuse_einsum_operands(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "torch.einsum") continue; @@ -74,7 +74,7 @@ void fuse_einsum_operands(Graph& graph) op2->inputs.clear(); op2->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2)); delete op2; diff --git 
a/tools/pnnx/src/pass_level3/fuse_einsum_operands.h b/tools/pnnx/src/pass_level3/fuse_einsum_operands.h index 21861b53445..7859cf46a60 100644 --- a/tools/pnnx/src/pass_level3/fuse_einsum_operands.h +++ b/tools/pnnx/src/pass_level3/fuse_einsum_operands.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_einsum_operands(Graph& graph); +void fuse_einsum_operands(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp index 2ed20abe7ce..1134b4a7274 100644 --- a/tools/pnnx/src/pass_level3/fuse_expression.cpp +++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp @@ -133,7 +133,7 @@ static bool operand_maybe_tensor(const Operand* operand) return true; } -static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector& inputs, const std::set& foldable_constants, StoreZipReader& zip, bool checksubgraph = true) +static void fuse_expression(std::shared_ptr graph, Operand* operand, std::string& expr, std::vector& inputs, const std::set& foldable_constants, StoreZipReader& zip, bool checksubgraph = true) { // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str()); @@ -728,7 +728,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s } } -void fuse_expression(Graph& graph, const std::set& foldable_constants, const std::string& foldable_constants_zippath) +void fuse_expression(std::shared_ptr graph, const std::set& foldable_constants, const std::string& foldable_constants_zippath) { StoreZipReader zip; zip.open(foldable_constants_zippath); @@ -740,9 +740,9 @@ void fuse_expression(Graph& graph, const std::set& foldable_constan bool need_fuse = false; // build expression via reverse order - for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + for (int i = (int)graph->ops.size() - 1; i >= 0; i--) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type == "prim::Constant") { diff --git a/tools/pnnx/src/pass_level3/fuse_expression.h b/tools/pnnx/src/pass_level3/fuse_expression.h index 0bd70975971..8f3e8b3f839 100644 --- a/tools/pnnx/src/pass_level3/fuse_expression.h +++ b/tools/pnnx/src/pass_level3/fuse_expression.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_expression(Graph& graph, const std::set& foldable_constants, const std::string& foldable_constants_zippath); +void fuse_expression(std::shared_ptr graph, const std::set& foldable_constants, const std::string& foldable_constants_zippath); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_index_expression.cpp b/tools/pnnx/src/pass_level3/fuse_index_expression.cpp index 0c8c4d8e48d..718d8697004 100644 --- a/tools/pnnx/src/pass_level3/fuse_index_expression.cpp +++ b/tools/pnnx/src/pass_level3/fuse_index_expression.cpp @@ -27,7 +27,6 @@ static void replaceAll(std::string& str, const std::string& from, const std::str start_pos += to.length(); } } - static void multi_expr(int depth, std::vector& attr_shape, const int64_t* pdata, std::string& attr_expr, int pre_depth, int& cur_index) { if(depth == attr_shape.size() - 1) @@ -133,15 +132,15 @@ static std::string fuse_attribute_expression(Operator* op_expr) return expr; } -void fuse_index_expression(Graph& graph) +void fuse_index_expression(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.index") continue; @@ -187,7 +186,7 @@ 
void fuse_index_expression(Graph& graph) op2->inputs.clear(); op2->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2)); delete op2; diff --git a/tools/pnnx/src/pass_level3/fuse_index_expression.h b/tools/pnnx/src/pass_level3/fuse_index_expression.h index 930cfa99cc6..cd681ad7c54 100644 --- a/tools/pnnx/src/pass_level3/fuse_index_expression.h +++ b/tools/pnnx/src/pass_level3/fuse_index_expression.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_index_expression(Graph& graph); +void fuse_index_expression(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.cpp b/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.cpp index 5cba43c00a1..40d37753c1b 100644 --- a/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.cpp +++ b/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.cpp @@ -18,15 +18,15 @@ namespace pnnx { -void fuse_maxpool_unpack(Graph& graph) +void fuse_maxpool_unpack(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "nn.MaxPool1d" && op->type != "nn.MaxPool2d" && op->type != "nn.MaxPool3d") continue; @@ -72,7 +72,7 @@ void fuse_maxpool_unpack(Graph& graph) op2->inputs.clear(); op2->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2)); delete op2; diff --git a/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.h b/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.h index 19bbf14425b..580a51022a8 100644 --- a/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.h +++ b/tools/pnnx/src/pass_level3/fuse_maxpool_unpack.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_maxpool_unpack(Graph& graph); +void fuse_maxpool_unpack(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_multiheadattention_unpack.cpp b/tools/pnnx/src/pass_level3/fuse_multiheadattention_unpack.cpp index 29ad1aef545..7e9dd3483c1 100644 --- a/tools/pnnx/src/pass_level3/fuse_multiheadattention_unpack.cpp +++ b/tools/pnnx/src/pass_level3/fuse_multiheadattention_unpack.cpp @@ -18,15 +18,15 @@ namespace pnnx { -void fuse_multiheadattention_unpack(Graph& graph) +void fuse_multiheadattention_unpack(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "nn.MultiheadAttention") continue; @@ -56,7 +56,7 @@ void fuse_multiheadattention_unpack(Graph& graph) op2->inputs.clear(); op2->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2)); delete op2; diff --git a/tools/pnnx/src/pass_level3/fuse_multiheadattention_unpack.h b/tools/pnnx/src/pass_level3/fuse_multiheadattention_unpack.h index 7285cebec4e..0ffd4573696 100644 --- a/tools/pnnx/src/pass_level3/fuse_multiheadattention_unpack.h +++ b/tools/pnnx/src/pass_level3/fuse_multiheadattention_unpack.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_multiheadattention_unpack(Graph& graph); +void fuse_multiheadattention_unpack(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp 
b/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp index 882b4dd11e8..34cfb07e652 100644 --- a/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp +++ b/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.cpp @@ -18,15 +18,15 @@ namespace pnnx { -void fuse_op1ton_unpack(Graph& graph) +void fuse_op1ton_unpack(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind" && op->type != "torch.tensor_split") continue; @@ -56,7 +56,7 @@ void fuse_op1ton_unpack(Graph& graph) op2->inputs.clear(); op2->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2)); delete op2; diff --git a/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h b/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h index a584c00cb4d..ec0306177da 100644 --- a/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h +++ b/tools/pnnx/src/pass_level3/fuse_op1ton_unpack.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_op1ton_unpack(Graph& graph); +void fuse_op1ton_unpack(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp b/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp index f3dcecbc727..43a16c5a844 100644 --- a/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp +++ b/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.cpp @@ -18,15 +18,15 @@ namespace pnnx { -void fuse_opnto1_tensors(Graph& graph) +void fuse_opnto1_tensors(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "torch.cat" && op->type != "torch.stack") continue; @@ -67,7 +67,7 @@ void fuse_opnto1_tensors(Graph& graph) op2->inputs.clear(); op2->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2)); delete op2; diff --git a/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h b/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h index 4fb990a484c..a4bc301dd87 100644 --- a/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h +++ b/tools/pnnx/src/pass_level3/fuse_opnto1_tensors.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_opnto1_tensors(Graph& graph); +void fuse_opnto1_tensors(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_rnn_unpack.cpp b/tools/pnnx/src/pass_level3/fuse_rnn_unpack.cpp index 42ba4223439..86f2ba29ffa 100644 --- a/tools/pnnx/src/pass_level3/fuse_rnn_unpack.cpp +++ b/tools/pnnx/src/pass_level3/fuse_rnn_unpack.cpp @@ -18,15 +18,15 @@ namespace pnnx { -void fuse_rnn_unpack(Graph& graph) +void fuse_rnn_unpack(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "nn.RNN" && op->type != "nn.LSTM" && op->type != "nn.GRU") continue; @@ -71,7 +71,7 @@ void fuse_rnn_unpack(Graph& graph) op2->inputs.clear(); op2->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op2)); delete op2; diff 
--git a/tools/pnnx/src/pass_level3/fuse_rnn_unpack.h b/tools/pnnx/src/pass_level3/fuse_rnn_unpack.h index 79f1f65df5a..2fb64fb4007 100644 --- a/tools/pnnx/src/pass_level3/fuse_rnn_unpack.h +++ b/tools/pnnx/src/pass_level3/fuse_rnn_unpack.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_rnn_unpack(Graph& graph); +void fuse_rnn_unpack(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/rename_F_conv_transposend.cpp b/tools/pnnx/src/pass_level3/rename_F_conv_transposend.cpp index 9f19e159828..dfae306d556 100644 --- a/tools/pnnx/src/pass_level3/rename_F_conv_transposend.cpp +++ b/tools/pnnx/src/pass_level3/rename_F_conv_transposend.cpp @@ -17,11 +17,11 @@ namespace pnnx { -void rename_F_conv_transposend(Graph& graph) +void rename_F_conv_transposend(std::shared_ptr graph) { - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "F.conv_transposend") continue; diff --git a/tools/pnnx/src/pass_level3/rename_F_conv_transposend.h b/tools/pnnx/src/pass_level3/rename_F_conv_transposend.h index 7192f28617e..2c708f45249 100644 --- a/tools/pnnx/src/pass_level3/rename_F_conv_transposend.h +++ b/tools/pnnx/src/pass_level3/rename_F_conv_transposend.h @@ -16,6 +16,6 @@ namespace pnnx { -void rename_F_conv_transposend(Graph& graph); +void rename_F_conv_transposend(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/rename_F_convmode.cpp b/tools/pnnx/src/pass_level3/rename_F_convmode.cpp index ba6424bdef9..c9f11a0a4af 100644 --- a/tools/pnnx/src/pass_level3/rename_F_convmode.cpp +++ b/tools/pnnx/src/pass_level3/rename_F_convmode.cpp @@ -17,11 +17,11 @@ namespace pnnx { -void rename_F_convmode(Graph& graph) +void rename_F_convmode(std::shared_ptr graph) { - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "F.convmode") continue; diff --git a/tools/pnnx/src/pass_level3/rename_F_convmode.h b/tools/pnnx/src/pass_level3/rename_F_convmode.h index 9057dff6e45..17cb6dabd3c 100644 --- a/tools/pnnx/src/pass_level3/rename_F_convmode.h +++ b/tools/pnnx/src/pass_level3/rename_F_convmode.h @@ -16,6 +16,6 @@ namespace pnnx { -void rename_F_convmode(Graph& graph); +void rename_F_convmode(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/rename_F_dropoutnd.cpp b/tools/pnnx/src/pass_level3/rename_F_dropoutnd.cpp index 3ddc7989af8..ccf7e3c252c 100644 --- a/tools/pnnx/src/pass_level3/rename_F_dropoutnd.cpp +++ b/tools/pnnx/src/pass_level3/rename_F_dropoutnd.cpp @@ -17,11 +17,11 @@ namespace pnnx { -void rename_F_dropoutnd(Graph& graph) +void rename_F_dropoutnd(std::shared_ptr graph) { - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "F.dropoutnd") continue; diff --git a/tools/pnnx/src/pass_level3/rename_F_dropoutnd.h b/tools/pnnx/src/pass_level3/rename_F_dropoutnd.h index 5bc48ea35ff..acfa048f7f0 100644 --- a/tools/pnnx/src/pass_level3/rename_F_dropoutnd.h +++ b/tools/pnnx/src/pass_level3/rename_F_dropoutnd.h @@ -16,6 +16,6 @@ namespace pnnx { -void rename_F_dropoutnd(Graph& graph); +void rename_F_dropoutnd(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level4.cpp b/tools/pnnx/src/pass_level4.cpp index 59553be8f07..a0a46400751 
100644 --- a/tools/pnnx/src/pass_level4.cpp +++ b/tools/pnnx/src/pass_level4.cpp @@ -21,7 +21,7 @@ namespace pnnx { // add by senli -void pass_level4(Graph& g, std::set& custom_ops) +void pass_level4(std::shared_ptr g, std::set& custom_ops) { fuse_custom_op(g, custom_ops); diff --git a/tools/pnnx/src/pass_level4.h b/tools/pnnx/src/pass_level4.h index d54edba5317..1b47a6a6911 100644 --- a/tools/pnnx/src/pass_level4.h +++ b/tools/pnnx/src/pass_level4.h @@ -20,7 +20,7 @@ namespace pnnx { // add by senli -void pass_level4(Graph& g, std::set& custom_ops); +void pass_level4(std::shared_ptr g, std::set& custom_ops); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level4/canonicalize.cpp b/tools/pnnx/src/pass_level4/canonicalize.cpp index 65017e2fb3f..36d6516ee39 100644 --- a/tools/pnnx/src/pass_level4/canonicalize.cpp +++ b/tools/pnnx/src/pass_level4/canonicalize.cpp @@ -16,11 +16,11 @@ namespace pnnx { -void canonicalize(Graph& graph) +void canonicalize(std::shared_ptr graph) { int i = 0; - for (Operator* op : graph.ops) + for (Operator* op : graph->ops) { for (Operand* operand : op->outputs) { diff --git a/tools/pnnx/src/pass_level4/canonicalize.h b/tools/pnnx/src/pass_level4/canonicalize.h index e65f19e1c3b..62983da8643 100644 --- a/tools/pnnx/src/pass_level4/canonicalize.h +++ b/tools/pnnx/src/pass_level4/canonicalize.h @@ -16,6 +16,6 @@ namespace pnnx { -void canonicalize(Graph& graph); +void canonicalize(std::shared_ptr graph); } diff --git a/tools/pnnx/src/pass_level4/dead_code_elimination.cpp b/tools/pnnx/src/pass_level4/dead_code_elimination.cpp index 800bd671572..dc4d5b8cc8d 100644 --- a/tools/pnnx/src/pass_level4/dead_code_elimination.cpp +++ b/tools/pnnx/src/pass_level4/dead_code_elimination.cpp @@ -16,16 +16,16 @@ namespace pnnx { -void dead_code_elimination(Graph& graph) +void dead_code_elimination(std::shared_ptr graph) { // dead op elimination for (;;) { bool need_eliminate = false; - for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + for (int i = (int)graph->ops.size() - 1; i >= 0; i--) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type == "pnnx.Output") continue; @@ -56,7 +56,7 @@ void dead_code_elimination(Graph& graph) op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; @@ -72,9 +72,9 @@ void dead_code_elimination(Graph& graph) { bool need_eliminate = false; - for (int i = (int)graph.operands.size() - 1; i >= 0; i--) + for (int i = (int)graph->operands.size() - 1; i >= 0; i--) { - Operand* operand = graph.operands[i]; + Operand* operand = graph->operands[i]; int consumers = (int)operand->consumers.size(); @@ -84,7 +84,7 @@ void dead_code_elimination(Graph& graph) // fprintf(stderr, "delete operand %s\n", operand->name.c_str()); - graph.operands.erase(graph.operands.begin() + i); + graph->operands.erase(graph->operands.begin() + i); delete operand; break; diff --git a/tools/pnnx/src/pass_level4/dead_code_elimination.h b/tools/pnnx/src/pass_level4/dead_code_elimination.h index 145b40904e5..f8e0b948edb 100644 --- a/tools/pnnx/src/pass_level4/dead_code_elimination.h +++ b/tools/pnnx/src/pass_level4/dead_code_elimination.h @@ -16,6 +16,6 @@ namespace pnnx { -void dead_code_elimination(Graph& graph); +void dead_code_elimination(std::shared_ptr graph); } diff --git a/tools/pnnx/src/pass_level4/fuse_custom_op.cpp b/tools/pnnx/src/pass_level4/fuse_custom_op.cpp index 0b41928f6ea..c3aaaccab20 100644 --- a/tools/pnnx/src/pass_level4/fuse_custom_op.cpp +++ 
b/tools/pnnx/src/pass_level4/fuse_custom_op.cpp
@@ -19,7 +19,7 @@ namespace pnnx {
 
 //add by senli
-void fuse_custom_op(Graph& graph, std::set<std::string>& custom_ops)
+void fuse_custom_op(std::shared_ptr<Graph> graph, std::set<std::string>& custom_ops)
 {
     //add by senli
     //std::set<std::string> custom_ops;
@@ -29,9 +29,9 @@ void fuse_custom_op(Graph& graph, std::set<std::string>& custom_ops)
         bool need_fuse = false;
 
         // fuse in reverse order
-        for (int i = (int)graph.ops.size() - 1; i >= 0; i--)
+        for (int i = (int)graph->ops.size() - 1; i >= 0; i--)
         {
-            Operator* op = graph.ops[i];
+            Operator* op = graph->ops[i];
 
             if (op->type.find("::") == std::string::npos)
                 continue;
@@ -46,7 +46,16 @@ void fuse_custom_op(Graph& graph, std::set<std::string>& custom_ops)
             need_fuse = true;
             //add by senli
             // op->type = std::string("pnnx.custom_op.") + op_type_namespace + '.' + op_type_name;
-            op->type = std::string("torch.ops.") + op_type_namespace + '.' + op_type_name;
+            if (op_type_namespace == "torchvision")
+            {
+                op->type = op_type_namespace + ".ops." + op_type_name;
+            }
+            else
+            {
+                op->type = std::string("torch.ops.") + op_type_namespace + '.' + op_type_name;
+            }
+
+            custom_ops.insert(op->type);
 
             std::vector<Operand*> new_inputs;
             std::vector<std::string> new_inputnames;
diff --git a/tools/pnnx/src/pass_level4/fuse_custom_op.h b/tools/pnnx/src/pass_level4/fuse_custom_op.h
index 7ee9b53846c..3175c765fc0 100644
--- a/tools/pnnx/src/pass_level4/fuse_custom_op.h
+++ b/tools/pnnx/src/pass_level4/fuse_custom_op.h
@@ -17,6 +17,6 @@ namespace pnnx {
 
 // add by senli
-void fuse_custom_op(Graph& graph, std::set<std::string>& custom_ops);
+void fuse_custom_op(std::shared_ptr<Graph> graph, std::set<std::string>& custom_ops);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp
index 7b91eabfe82..32ffc6bc9c3 100644
--- a/tools/pnnx/src/pass_level5.cpp
+++ b/tools/pnnx/src/pass_level5.cpp
@@ -65,7 +65,7 @@ namespace pnnx {
 
-void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
+void pass_level5(std::shared_ptr<Graph> g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
 {
     eval_expression(g);
 
@@ -134,8 +134,9 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons
     eliminate_reshape_shape_expression(g);
 
     eliminate_noop_expand(g);
-
+#ifdef NDEBUG
     fuse_channel_shuffle(g);
+#endif
     fuse_layernorm(g);
     fuse_multiheadattention(g);
     fuse_scaled_dot_product_attention(g);
diff --git a/tools/pnnx/src/pass_level5.h b/tools/pnnx/src/pass_level5.h
index a040c7bf145..f86ea6a67ed 100644
--- a/tools/pnnx/src/pass_level5.h
+++ b/tools/pnnx/src/pass_level5.h
@@ -19,7 +19,7 @@ namespace pnnx {
 
-void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
+void pass_level5(std::shared_ptr<Graph> g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/attribute_unpooling.cpp b/tools/pnnx/src/pass_level5/attribute_unpooling.cpp
index 76e7f762202..e58a9d19816 100644
--- a/tools/pnnx/src/pass_level5/attribute_unpooling.cpp
+++ b/tools/pnnx/src/pass_level5/attribute_unpooling.cpp
@@ -18,15 +18,15 @@ namespace pnnx {
 
-void attribute_unpooling(Graph& graph)
+void attribute_unpooling(std::shared_ptr<Graph> graph)
 {
     while (1)
     {
         bool matched = false;
 
-        for (size_t i = 0; i < graph.ops.size(); i++)
+        for (size_t i = 0; i < graph->ops.size(); i++)
        {
-            Operator* op = graph.ops[i];
+            Operator* op = graph->ops[i];
 
             if (op->type != "pnnx.Attribute")
                 continue;
 
@@ -43,13 +43,13 @@ void attribute_unpooling(Graph& graph)
             {
Operator* x = attr->consumers[i]; - Operator* op2 = graph.new_operator_after("pnnx.Attribute", op->name + "_" + std::to_string(i), op); + Operator* op2 = graph->new_operator_after("pnnx.Attribute", op->name + "_" + std::to_string(i), op); op2->inputnames = op->inputnames; op2->params = op->params; op2->attrs = op->attrs; - Operand* attr2 = graph.new_operand(attr->name + "_" + std::to_string(i)); + Operand* attr2 = graph->new_operand(attr->name + "_" + std::to_string(i)); attr2->type = attr->type; attr2->shape = attr->shape; diff --git a/tools/pnnx/src/pass_level5/attribute_unpooling.h b/tools/pnnx/src/pass_level5/attribute_unpooling.h index 333c709bebd..2fa8fb43d31 100644 --- a/tools/pnnx/src/pass_level5/attribute_unpooling.h +++ b/tools/pnnx/src/pass_level5/attribute_unpooling.h @@ -16,6 +16,6 @@ namespace pnnx { -void attribute_unpooling(Graph& g); +void attribute_unpooling(std::shared_ptr g); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp index 9809ac18930..12e14b605bb 100644 --- a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_dropout(Graph& graph) +void eliminate_dropout(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "F.alpha_dropout" && op->type != "F.dropout" && op->type != "F.dropout2d" && op->type != "F.dropout3d" && op->type != "F.feature_alpha_dropout" && op->type != "nn.AlphaDropout" && op->type != "nn.Dropout" && op->type != "nn.Dropout2d" && op->type != "nn.Dropout3d") continue; @@ -58,13 +58,13 @@ void eliminate_dropout(Graph& graph) dropout_out->producer = 0; dropout_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), dropout_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), dropout_out)); delete dropout_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_dropout.h b/tools/pnnx/src/pass_level5/eliminate_dropout.h index d3636611aa1..203c727172b 100644 --- a/tools/pnnx/src/pass_level5/eliminate_dropout.h +++ b/tools/pnnx/src/pass_level5/eliminate_dropout.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_dropout(Graph& graph); +void eliminate_dropout(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_identity_operator.cpp b/tools/pnnx/src/pass_level5/eliminate_identity_operator.cpp index 02040406e55..ffee8f40931 100644 --- a/tools/pnnx/src/pass_level5/eliminate_identity_operator.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_identity_operator.cpp @@ -19,24 +19,24 @@ namespace pnnx { -void eliminate_identity_operator(Graph& graph) +void eliminate_identity_operator(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op0 = graph.ops[i]; + Operator* op0 = graph->ops[i]; if (op0->type == "pnnx.Input" || op0->type == "pnnx.Output" || op0->type == "pnnx.Attribute" || op0->type == "torch.clone") continue; Operator* op1 = 0; - for (size_t j = i + 1; j < graph.ops.size(); j++) + for (size_t j = i + 1; j < 
graph->ops.size(); j++) { - op1 = graph.ops[j]; + op1 = graph->ops[j]; if (op1->type == "pnnx.Input" || op1->type == "pnnx.Output" || op0->type == "pnnx.Attribute" || op1->type == "torch.clone") continue; @@ -97,14 +97,14 @@ void eliminate_identity_operator(Graph& graph) // delete op1 and its output operands for (int j = 0; j < output_count; j++) { - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op1->outputs[j])); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), op1->outputs[j])); delete op1->outputs[j]; } op1->inputs.clear(); op1->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op1)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op1)); delete op1; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_identity_operator.h b/tools/pnnx/src/pass_level5/eliminate_identity_operator.h index 7ff0299a2c8..7960d6cc548 100644 --- a/tools/pnnx/src/pass_level5/eliminate_identity_operator.h +++ b/tools/pnnx/src/pass_level5/eliminate_identity_operator.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_identity_operator(Graph& graph); +void eliminate_identity_operator(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_maxpool_indices.cpp b/tools/pnnx/src/pass_level5/eliminate_maxpool_indices.cpp index a677fef3a31..21f28deacd8 100644 --- a/tools/pnnx/src/pass_level5/eliminate_maxpool_indices.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_maxpool_indices.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_maxpool_indices(Graph& graph) +void eliminate_maxpool_indices(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "F.adaptive_max_pool1d" && op->type != "F.adaptive_max_pool2d" && op->type != "F.adaptive_max_pool3d" && op->type != "F.max_pool1d" && op->type != "F.max_pool2d" && op->type != "F.max_pool3d" @@ -56,7 +56,7 @@ void eliminate_maxpool_indices(Graph& graph) op_indices->producer = 0; - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_indices)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), op_indices)); delete op_indices; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_maxpool_indices.h b/tools/pnnx/src/pass_level5/eliminate_maxpool_indices.h index deecf31690a..28d436cf95b 100644 --- a/tools/pnnx/src/pass_level5/eliminate_maxpool_indices.h +++ b/tools/pnnx/src/pass_level5/eliminate_maxpool_indices.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_maxpool_indices(Graph& graph); +void eliminate_maxpool_indices(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_cat.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_cat.cpp index f09911819ca..2e7781ebbba 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_cat.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_cat.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_noop_cat(Graph& graph) +void eliminate_noop_cat(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "torch.cat") continue; @@ -58,13 +58,13 @@ void eliminate_noop_cat(Graph& graph) cat_out->producer = 0; 
cat_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), cat_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), cat_out)); delete cat_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_cat.h b/tools/pnnx/src/pass_level5/eliminate_noop_cat.h index 6186ed4958c..8ad9a3fc1c0 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_cat.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_cat.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_cat(Graph& graph); +void eliminate_noop_cat(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_einsum.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_einsum.cpp index f20522a1600..40e17f01032 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_einsum.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_einsum.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_noop_einsum(Graph& graph) +void eliminate_noop_einsum(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "torch.einsum") continue; @@ -63,13 +63,13 @@ void eliminate_noop_einsum(Graph& graph) einsum_out->producer = 0; einsum_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), einsum_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), einsum_out)); delete einsum_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_einsum.h b/tools/pnnx/src/pass_level5/eliminate_noop_einsum.h index 96e3811a40a..f2ec1d3fc71 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_einsum.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_einsum.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_einsum(Graph& graph); +void eliminate_noop_einsum(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_expand.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_expand.cpp index 33257617b31..b7c8c07c429 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_expand.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_expand.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_noop_expand(Graph& graph) +void eliminate_noop_expand(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.expand_as" && op->type != "Tensor.expand") continue; @@ -136,13 +136,13 @@ void eliminate_noop_expand(Graph& graph) expand_out->producer = 0; expand_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), expand_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), expand_out)); delete expand_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_expand.h 
b/tools/pnnx/src/pass_level5/eliminate_noop_expand.h index cda5974a942..03e976e484b 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_expand.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_expand.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_expand(Graph& graph); +void eliminate_noop_expand(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp index 02f9a93422a..3264365c4d2 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_noop_expression(Graph& graph) +void eliminate_noop_expression(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "pnnx.Expression") continue; @@ -65,13 +65,13 @@ void eliminate_noop_expression(Graph& graph) expr_out->producer = 0; expr_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), expr_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), expr_out)); delete expr_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_expression.h b/tools/pnnx/src/pass_level5/eliminate_noop_expression.h index 7c015eab186..0e4283a07fe 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_expression.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_expression.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_expression(Graph& graph); +void eliminate_noop_expression(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp index 7b1ca582784..caf1cb7c32b 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_noop_pad(Graph& graph) +void eliminate_noop_pad(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "F.pad") continue; @@ -76,13 +76,13 @@ void eliminate_noop_pad(Graph& graph) pad_out->producer = 0; pad_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), pad_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), pad_out)); delete pad_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_pad.h b/tools/pnnx/src/pass_level5/eliminate_noop_pad.h index 359c046349b..f661c333965 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_pad.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_pad.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_pad(Graph& graph); +void eliminate_noop_pad(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp index 
5e31b772897..2d6d5aec127 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_slice.cpp @@ -20,15 +20,15 @@ namespace pnnx { -void eliminate_noop_slice(Graph& graph) +void eliminate_noop_slice(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.slice") continue; @@ -76,13 +76,13 @@ void eliminate_noop_slice(Graph& graph) slice_out->producer = 0; slice_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), slice_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), slice_out)); delete slice_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_slice.h b/tools/pnnx/src/pass_level5/eliminate_noop_slice.h index 162109d2a66..ba0ab2ee4d2 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_slice.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_slice.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_slice(Graph& graph); +void eliminate_noop_slice(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_upsample.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_upsample.cpp index 30f56b768c6..2ce7831cfaa 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_upsample.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_upsample.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_noop_upsample(Graph& graph) +void eliminate_noop_upsample(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "F.upsample" && op->type != "F.upsample_bilinear" && op->type != "F.upsample_nearest" && op->type != "F.interpolate" && op->type != "nn.Upsample" && op->type != "nn.UpsamplingBilinear2d" && op->type != "nn.UpsamplingNearest2d") @@ -104,13 +104,13 @@ void eliminate_noop_upsample(Graph& graph) upsample_out->producer = 0; upsample_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), upsample_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), upsample_out)); delete upsample_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_upsample.h b/tools/pnnx/src/pass_level5/eliminate_noop_upsample.h index 985781f225b..f059441e8b6 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_upsample.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_upsample.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_upsample(Graph& graph); +void eliminate_noop_upsample(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp index 1f79a7a6325..a7bbbc7d202 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void eliminate_noop_view_reshape(Graph& 
graph) +void eliminate_noop_view_reshape(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.view" && op->type != "Tensor.reshape") continue; @@ -78,13 +78,13 @@ void eliminate_noop_view_reshape(Graph& graph) op_out->producer = 0; op_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), op_out)); delete op_out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; diff --git a/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h index 1d724d99c41..45e0ec93073 100644 --- a/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h +++ b/tools/pnnx/src/pass_level5/eliminate_noop_view_reshape.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_noop_view_reshape(Graph& graph); +void eliminate_noop_view_reshape(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp b/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp index 662f696857f..edca798b923 100644 --- a/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.cpp @@ -97,15 +97,15 @@ static std::string build_expr(const std::vector& expr_tokens) return expr; } -void eliminate_reshape_shape_expression(Graph& graph) +void eliminate_reshape_shape_expression(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.view" && op->type != "Tensor.reshape") continue; @@ -172,13 +172,13 @@ void eliminate_reshape_shape_expression(Graph& graph) Operand* op_expr_out = op_expr->outputs[0]; - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_expr_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), op_expr_out)); delete op_expr_out; op_expr->inputs.clear(); op_expr->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_expr)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op_expr)); delete op_expr; } @@ -189,9 +189,9 @@ void eliminate_reshape_shape_expression(Graph& graph) break; } - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.view" && op->type != "Tensor.reshape") continue; diff --git a/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h b/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h index d4457c3acda..e62c5431d6e 100644 --- a/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h +++ b/tools/pnnx/src/pass_level5/eliminate_reshape_shape_expression.h @@ -16,6 +16,6 @@ namespace pnnx { -void eliminate_reshape_shape_expression(Graph& graph); +void eliminate_reshape_shape_expression(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_type_as.cpp b/tools/pnnx/src/pass_level5/eliminate_type_as.cpp index 
c7290fb0480..7c9a61b6b62 100644
--- a/tools/pnnx/src/pass_level5/eliminate_type_as.cpp
+++ b/tools/pnnx/src/pass_level5/eliminate_type_as.cpp
@@ -19,15 +19,15 @@
 namespace pnnx {
 
-void eliminate_type_as(Graph& graph)
+void eliminate_type_as(std::shared_ptr<Graph> graph)
 {
     while (1)
     {
         bool matched = false;
 
-        for (size_t i = 0; i < graph.ops.size(); i++)
+        for (size_t i = 0; i < graph->ops.size(); i++)
         {
-            Operator* op = graph.ops[i];
+            Operator* op = graph->ops[i];
 
             if (op->type != "Tensor.type_as")
                 continue;
@@ -64,13 +64,13 @@ void eliminate_type_as(Graph& graph)
 
             type_as_out->producer = 0;
             type_as_out->consumers.clear();
 
-            graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), type_as_out));
+            graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), type_as_out));
             delete type_as_out;
 
             op->inputs.clear();
             op->outputs.clear();
 
-            graph.ops.erase(graph.ops.begin() + i);
+            graph->ops.erase(graph->ops.begin() + i);
             delete op;
 
             break;
diff --git a/tools/pnnx/src/pass_level5/eliminate_type_as.h b/tools/pnnx/src/pass_level5/eliminate_type_as.h
index 46ec5ad571b..47f642b06f5 100644
--- a/tools/pnnx/src/pass_level5/eliminate_type_as.h
+++ b/tools/pnnx/src/pass_level5/eliminate_type_as.h
@@ -16,6 +16,6 @@
 namespace pnnx {
 
-void eliminate_type_as(Graph& graph);
+void eliminate_type_as(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp
index 10b38b9fcb8..9d80ee8b85a 100644
--- a/tools/pnnx/src/pass_level5/eval_expression.cpp
+++ b/tools/pnnx/src/pass_level5/eval_expression.cpp
@@ -626,9 +626,9 @@ static std::string canonicalize_arguments(const Operator* op, std::vector<Operand*>& inputs)
 
-void eval_expression(Graph& graph)
+void eval_expression(std::shared_ptr<Graph> graph)
 {
-    for (Operator* op : graph.ops)
+    for (Operator* op : graph->ops)
     {
         if (op->type != "pnnx.Expression")
             continue;
diff --git a/tools/pnnx/src/pass_level5/eval_expression.h b/tools/pnnx/src/pass_level5/eval_expression.h
index 149ef82ce79..7f461d22462 100644
--- a/tools/pnnx/src/pass_level5/eval_expression.h
+++ b/tools/pnnx/src/pass_level5/eval_expression.h
@@ -16,6 +16,6 @@
 namespace pnnx {
 
-void eval_expression(Graph& graph);
+void eval_expression(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fold_constants.cpp b/tools/pnnx/src/pass_level5/fold_constants.cpp
index 47906f76175..10d07c88989 100644
--- a/tools/pnnx/src/pass_level5/fold_constants.cpp
+++ b/tools/pnnx/src/pass_level5/fold_constants.cpp
@@ -20,7 +20,7 @@
 namespace pnnx {
 
-void fold_constants(Graph& graph, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
+void fold_constants(std::shared_ptr<Graph> graph, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
 {
     if (foldable_constants.empty())
         return;
@@ -28,9 +28,9 @@ void fold_constants(Graph& graph, const std::set<std::string>& foldable_constant
     StoreZipReader zip;
     zip.open(foldable_constants_zippath);
 
-    for (size_t i = 0; i < graph.operands.size(); i++)
+    for (size_t i = 0; i < graph->operands.size(); i++)
     {
-        Operand* operand = graph.operands[i];
+        Operand* operand = graph->operands[i];
         const std::string& name = operand->name;
 
         if (foldable_constants.find(name) == foldable_constants.end())
@@ -41,7 +41,7 @@ void fold_constants(Graph& graph, const std::set<std::string>& foldable_constant
             continue;
 
         // replace producer with attribute
-        Operator* op_new = graph.new_operator_before("pnnx.Attribute", std::string("pnnx_fold_") + name, op);
+        Operator* op_new = graph->new_operator_before("pnnx.Attribute", std::string("pnnx_fold_") + name, op);
 
         op_new->attrs["data"] = Attribute();
diff --git a/tools/pnnx/src/pass_level5/fold_constants.h b/tools/pnnx/src/pass_level5/fold_constants.h
index 0d96f9fbd0c..91c0c67ec4b 100644
--- a/tools/pnnx/src/pass_level5/fold_constants.h
+++ b/tools/pnnx/src/pass_level5/fold_constants.h
@@ -16,6 +16,6 @@
 namespace pnnx {
 
-void fold_constants(Graph& graph, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
+void fold_constants(std::shared_ptr<Graph> graph, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp b/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp
index f8505072129..75df7614e6d 100644
--- a/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp
+++ b/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.cpp
@@ -19,15 +19,15 @@
 namespace pnnx {
 
-void fuse_adjacent_reshape(Graph& graph)
+void fuse_adjacent_reshape(std::shared_ptr<Graph> graph)
 {
     while (1)
     {
         bool matched = false;
 
-        for (int i = (int)graph.ops.size() - 1; i > 0; i--)
+        for (int i = (int)graph->ops.size() - 1; i > 0; i--)
         {
-            Operator* op = graph.ops[i];
+            Operator* op = graph->ops[i];
 
             // look for Tensor.view / Tensor.reshape / torch.squeeze / torch.unsqueeze chain
             if (op->type != "Tensor.view" && op->type != "Tensor.reshape" && op->type != "torch.squeeze" && op->type != "torch.unsqueeze")
@@ -84,13 +84,13 @@ void fuse_adjacent_reshape(Graph& graph)
 
                 op0_out->producer = 0;
                 op0_out->consumers.clear();
 
-                graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op0_out));
+                graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), op0_out));
                 delete op0_out;
 
                 op0->inputs.clear();
                 op0->outputs.clear();
 
-                graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op0));
+                graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op0));
                 delete op0;
             }
diff --git a/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h b/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h
index 7f3fb51cdf3..14ee2c2a2ef 100644
--- a/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h
+++ b/tools/pnnx/src/pass_level5/fuse_adjacent_reshape.h
@@ -16,6 +16,6 @@
 namespace pnnx {
 
-void fuse_adjacent_reshape(Graph& graph);
+void fuse_adjacent_reshape(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp
index 3a38f594179..c1410c85f96 100644
--- a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp
+++ b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp
@@ -77,7 +77,7 @@ pnnx.Output output 1 0 out
 }
 };
 
-void fuse_channel_shuffle(Graph& graph)
+void fuse_channel_shuffle(std::shared_ptr<Graph> graph)
 {
     fuse_channel_shuffle_pass a;
     fuse_channel_shuffle_pass_1 b;
diff --git a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.h b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.h
index 3257f9cda97..d15b226d5aa 100644
--- a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.h
+++ b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.h
@@ -16,6 +16,6 @@
 namespace pnnx {
 
-void fuse_channel_shuffle(Graph& graph);
+void fuse_channel_shuffle(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fuse_constant_expression.cpp b/tools/pnnx/src/pass_level5/fuse_constant_expression.cpp
index 7f04b6458af..7422f39a624 100644
--- a/tools/pnnx/src/pass_level5/fuse_constant_expression.cpp
+++ b/tools/pnnx/src/pass_level5/fuse_constant_expression.cpp
@@ -20,15 +20,15 @@ namespace pnnx { -void fuse_constant_expression(Graph& graph) +void fuse_constant_expression(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "pnnx.Expression") continue; @@ -94,13 +94,13 @@ void fuse_constant_expression(Graph& graph) // delete expression and expr_output expr_output->producer = 0; - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), expr_output)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), expr_output)); delete expr_output; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op)); delete op; } diff --git a/tools/pnnx/src/pass_level5/fuse_constant_expression.h b/tools/pnnx/src/pass_level5/fuse_constant_expression.h index bb6a5937bd7..93f4b8cb4d6 100644 --- a/tools/pnnx/src/pass_level5/fuse_constant_expression.h +++ b/tools/pnnx/src/pass_level5/fuse_constant_expression.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_constant_expression(Graph& graph); +void fuse_constant_expression(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_contiguous_view.cpp b/tools/pnnx/src/pass_level5/fuse_contiguous_view.cpp index e42a72c9227..7c0c54b09ee 100644 --- a/tools/pnnx/src/pass_level5/fuse_contiguous_view.cpp +++ b/tools/pnnx/src/pass_level5/fuse_contiguous_view.cpp @@ -69,7 +69,7 @@ pnnx.Output output 1 0 out } }; -void fuse_contiguous_view(Graph& graph) +void fuse_contiguous_view(std::shared_ptr graph) { fuse_contiguous_view_pass a; fuse_contiguous_view_pass_1 b; diff --git a/tools/pnnx/src/pass_level5/fuse_contiguous_view.h b/tools/pnnx/src/pass_level5/fuse_contiguous_view.h index 33612c867a9..b87b05b51bf 100644 --- a/tools/pnnx/src/pass_level5/fuse_contiguous_view.h +++ b/tools/pnnx/src/pass_level5/fuse_contiguous_view.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_contiguous_view(Graph& graph); +void fuse_contiguous_view(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_conv1d_batchnorm1d.cpp b/tools/pnnx/src/pass_level5/fuse_conv1d_batchnorm1d.cpp index c471dc6a52e..601dee887df 100644 --- a/tools/pnnx/src/pass_level5/fuse_conv1d_batchnorm1d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_conv1d_batchnorm1d.cpp @@ -127,7 +127,7 @@ pnnx.Output output 1 0 out } }; -void fuse_conv1d_batchnorm1d(Graph& graph) +void fuse_conv1d_batchnorm1d(std::shared_ptr graph) { fuse_conv1d_batchnorm1d_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_conv1d_batchnorm1d.h b/tools/pnnx/src/pass_level5/fuse_conv1d_batchnorm1d.h index 89a2bd46504..8105a920904 100644 --- a/tools/pnnx/src/pass_level5/fuse_conv1d_batchnorm1d.h +++ b/tools/pnnx/src/pass_level5/fuse_conv1d_batchnorm1d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_conv1d_batchnorm1d(Graph& graph); +void fuse_conv1d_batchnorm1d(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.cpp b/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.cpp index 3c207c71a77..094c12be5f3 100644 --- a/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.cpp @@ -127,7 +127,7 @@ pnnx.Output output 1 0 out } }; -void fuse_conv2d_batchnorm2d(Graph& graph) +void 
fuse_conv2d_batchnorm2d(std::shared_ptr graph) { fuse_conv2d_batchnorm2d_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.h b/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.h index 829aa5a6da2..e12ddcf9348 100644 --- a/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.h +++ b/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_conv2d_batchnorm2d(Graph& graph); +void fuse_conv2d_batchnorm2d(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.cpp b/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.cpp index ea89e99cb53..4090875a1f1 100644 --- a/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.cpp @@ -127,7 +127,7 @@ pnnx.Output output 1 0 out } }; -void fuse_conv3d_batchnorm3d(Graph& graph) +void fuse_conv3d_batchnorm3d(std::shared_ptr graph) { fuse_conv3d_batchnorm3d_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.h b/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.h index 017201d4d8b..079ee843df0 100644 --- a/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.h +++ b/tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_conv3d_batchnorm3d(Graph& graph); +void fuse_conv3d_batchnorm3d(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose1d_batchnorm1d.cpp b/tools/pnnx/src/pass_level5/fuse_convtranspose1d_batchnorm1d.cpp index 0b0603ffead..8a0bc169240 100644 --- a/tools/pnnx/src/pass_level5/fuse_convtranspose1d_batchnorm1d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_convtranspose1d_batchnorm1d.cpp @@ -142,7 +142,7 @@ pnnx.Output output 1 0 out } }; -void fuse_convtranspose1d_batchnorm1d(Graph& graph) +void fuse_convtranspose1d_batchnorm1d(std::shared_ptr graph) { fuse_convtranspose1d_batchnorm1d_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose1d_batchnorm1d.h b/tools/pnnx/src/pass_level5/fuse_convtranspose1d_batchnorm1d.h index 68ced8ecc36..694a950faef 100644 --- a/tools/pnnx/src/pass_level5/fuse_convtranspose1d_batchnorm1d.h +++ b/tools/pnnx/src/pass_level5/fuse_convtranspose1d_batchnorm1d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_convtranspose1d_batchnorm1d(Graph& graph); +void fuse_convtranspose1d_batchnorm1d(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.cpp b/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.cpp index fc3244654bf..d1002a6939f 100644 --- a/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.cpp @@ -144,7 +144,7 @@ pnnx.Output output 1 0 out } }; -void fuse_convtranspose2d_batchnorm2d(Graph& graph) +void fuse_convtranspose2d_batchnorm2d(std::shared_ptr graph) { fuse_convtranspose2d_batchnorm2d_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.h b/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.h index 854b72a250f..7413cf6ebc9 100644 --- a/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.h +++ b/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_convtranspose2d_batchnorm2d(Graph& graph); +void fuse_convtranspose2d_batchnorm2d(std::shared_ptr graph); } // namespace pnnx diff --git 
a/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.cpp b/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.cpp index d01eebeed48..3bb691c267f 100644 --- a/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.cpp @@ -145,7 +145,7 @@ pnnx.Output output 1 0 out } }; -void fuse_convtranspose3d_batchnorm3d(Graph& graph) +void fuse_convtranspose3d_batchnorm3d(std::shared_ptr graph) { fuse_convtranspose3d_batchnorm3d_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.h b/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.h index f15e2f41c66..8eb757df345 100644 --- a/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.h +++ b/tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_convtranspose3d_batchnorm3d(Graph& graph); +void fuse_convtranspose3d_batchnorm3d(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_layernorm.cpp b/tools/pnnx/src/pass_level5/fuse_layernorm.cpp index c52201f8922..47b9d55475e 100644 --- a/tools/pnnx/src/pass_level5/fuse_layernorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_layernorm.cpp @@ -76,7 +76,7 @@ pnnx.Output output 1 0 out } }; -void fuse_layernorm(Graph& graph) +void fuse_layernorm(std::shared_ptr graph) { fuse_layernorm_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_layernorm.h b/tools/pnnx/src/pass_level5/fuse_layernorm.h index ac8c82cf80c..46464e46882 100644 --- a/tools/pnnx/src/pass_level5/fuse_layernorm.h +++ b/tools/pnnx/src/pass_level5/fuse_layernorm.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_layernorm(Graph& graph); +void fuse_layernorm(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.cpp b/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.cpp index 679cf7ec9fa..cc05f1443af 100644 --- a/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.cpp @@ -120,7 +120,7 @@ pnnx.Output output 1 0 out } }; -void fuse_linear_batchnorm1d(Graph& graph) +void fuse_linear_batchnorm1d(std::shared_ptr graph) { fuse_linear_batchnorm1d_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.h b/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.h index b04e03332eb..fd6044f5ee1 100644 --- a/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.h +++ b/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_linear_batchnorm1d(Graph& graph); +void fuse_linear_batchnorm1d(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp index 336ad9dfbbb..0bbafb2b34a 100644 --- a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp +++ b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp @@ -1368,7 +1368,7 @@ pnnx.Output output 1 0 out } }; -void fuse_multiheadattention(Graph& graph) +void fuse_multiheadattention(std::shared_ptr graph) { #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9) fuse_multiheadattention_pass a; diff --git a/tools/pnnx/src/pass_level5/fuse_multiheadattention.h b/tools/pnnx/src/pass_level5/fuse_multiheadattention.h index d8c1914d24e..ec81ca98a2b 100644 --- a/tools/pnnx/src/pass_level5/fuse_multiheadattention.h +++ 
b/tools/pnnx/src/pass_level5/fuse_multiheadattention.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_multiheadattention(Graph& graph); +void fuse_multiheadattention(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp index c78db4c66d9..42c3800de13 100644 --- a/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.cpp @@ -382,7 +382,7 @@ pnnx.Output output 1 0 out } }; -void fuse_pad_conv1d(Graph& graph) +void fuse_pad_conv1d(std::shared_ptr graph) { fuse_pad_conv1d_pass a; fuse_pad_conv1d_pass_1 b; diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv1d.h b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.h index f121b340cb0..baa7aa8d0e7 100644 --- a/tools/pnnx/src/pass_level5/fuse_pad_conv1d.h +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv1d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_pad_conv1d(Graph& graph); +void fuse_pad_conv1d(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp index 0823168cea8..b28b6d9f0c4 100644 --- a/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.cpp @@ -479,7 +479,7 @@ pnnx.Output output 1 0 out } }; -void fuse_pad_conv2d(Graph& graph) +void fuse_pad_conv2d(std::shared_ptr graph) { fuse_pad_conv2d_pass a; fuse_pad_conv2d_pass_1 b; diff --git a/tools/pnnx/src/pass_level5/fuse_pad_conv2d.h b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.h index fb47be50ec7..5f3661b072f 100644 --- a/tools/pnnx/src/pass_level5/fuse_pad_conv2d.h +++ b/tools/pnnx/src/pass_level5/fuse_pad_conv2d.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_pad_conv2d(Graph& graph); +void fuse_pad_conv2d(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_pixel_unshuffle.cpp b/tools/pnnx/src/pass_level5/fuse_pixel_unshuffle.cpp index 7a24093888f..e2b137b4366 100644 --- a/tools/pnnx/src/pass_level5/fuse_pixel_unshuffle.cpp +++ b/tools/pnnx/src/pass_level5/fuse_pixel_unshuffle.cpp @@ -72,7 +72,7 @@ pnnx.Output output 1 0 out } }; -void fuse_pixel_unshuffle(Graph& graph) +void fuse_pixel_unshuffle(std::shared_ptr graph) { fuse_pixel_unshuffle_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_pixel_unshuffle.h b/tools/pnnx/src/pass_level5/fuse_pixel_unshuffle.h index d852d5c4346..78b7e79bce2 100644 --- a/tools/pnnx/src/pass_level5/fuse_pixel_unshuffle.h +++ b/tools/pnnx/src/pass_level5/fuse_pixel_unshuffle.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_pixel_unshuffle(Graph& graph); +void fuse_pixel_unshuffle(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp index 38f1375445b..b6022ca9493 100644 --- a/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp +++ b/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp @@ -137,7 +137,7 @@ pnnx.Output output 1 0 out } }; -void fuse_scaled_dot_product_attention(Graph& graph) +void fuse_scaled_dot_product_attention(std::shared_ptr graph) { #if TORCH_VERSION_MAJOR >= 2 fuse_scaled_dot_product_attention_pass a; diff --git a/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h b/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h index 0eb13015c9a..015f04d9536 100644 --- a/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h +++ 
b/tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_scaled_dot_product_attention(Graph& graph); +void fuse_scaled_dot_product_attention(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp index 5a21f45c5db..4762ab59bda 100644 --- a/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp +++ b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void fuse_select_to_unbind(Graph& graph) +void fuse_select_to_unbind(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.select") continue; @@ -86,7 +86,7 @@ void fuse_select_to_unbind(Graph& graph) matched = true; // delete all select ops and replace with unbind - Operator* op_unbind = graph.new_operator_before("torch.unbind", op->name, op); + Operator* op_unbind = graph->new_operator_before("torch.unbind", op->name, op); op_unbind->params["dim"] = dim; op_unbind->inputs.push_back(op_in); @@ -105,7 +105,7 @@ void fuse_select_to_unbind(Graph& graph) for (int j = 0; j < select_dimsize; j++) { - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), select_n_ops[j])); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), select_n_ops[j])); delete select_n_ops[j]; } diff --git a/tools/pnnx/src/pass_level5/fuse_select_to_unbind.h b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.h index b48e644329e..09d601f4096 100644 --- a/tools/pnnx/src/pass_level5/fuse_select_to_unbind.h +++ b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_select_to_unbind(Graph& graph); +void fuse_select_to_unbind(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp b/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp index eccae6f1154..a1a535a139b 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp +++ b/tools/pnnx/src/pass_level5/fuse_slice_copy.cpp @@ -21,15 +21,15 @@ namespace pnnx { -void fuse_slice_copy(Graph& graph) +void fuse_slice_copy(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.copy") continue; @@ -105,13 +105,13 @@ void fuse_slice_copy(Graph& graph) out->producer = 0; out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), out)); delete out; op->inputs.clear(); op->outputs.clear(); - graph.ops.erase(graph.ops.begin() + i); + graph->ops.erase(graph->ops.begin() + i); delete op; break; @@ -122,8 +122,8 @@ void fuse_slice_copy(Graph& graph) op->type = "Tensor.slice_copy"; // insert clone just after the producer - Operator* op_clone = graph.new_operator_after("Tensor.clone", op->name + "_ncnnclone", top_sop->inputs[0]->producer); - Operand* clone_out = graph.new_operand(op->name + "_ncnnclone_out"); + Operator* op_clone = graph->new_operator_after("Tensor.clone", op->name + "_ncnnclone", top_sop->inputs[0]->producer); + Operand* clone_out = graph->new_operand(op->name + "_ncnnclone_out"); clone_out->type = 
top_sop->inputs[0]->type; clone_out->shape = top_sop->inputs[0]->shape; diff --git a/tools/pnnx/src/pass_level5/fuse_slice_copy.h b/tools/pnnx/src/pass_level5/fuse_slice_copy.h index db3aef77359..cd3894816c7 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_copy.h +++ b/tools/pnnx/src/pass_level5/fuse_slice_copy.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_slice_copy(Graph& graph); +void fuse_slice_copy(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_slice_indices.cpp b/tools/pnnx/src/pass_level5/fuse_slice_indices.cpp index 332ca6b576d..2da23995b89 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_indices.cpp +++ b/tools/pnnx/src/pass_level5/fuse_slice_indices.cpp @@ -22,15 +22,15 @@ namespace pnnx { -void fuse_slice_indices(Graph& graph) +void fuse_slice_indices(std::shared_ptr graph) { while (1) { bool matched = false; - for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + for (int i = (int)graph->ops.size() - 1; i >= 0; i--) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.slice" && op->type != "Tensor.select") continue; @@ -158,10 +158,10 @@ void fuse_slice_indices(Graph& graph) Operator* op_ends = 0; Operator* op_steps = 0; Operator* op_selects = 0; - if (!static_starts) op_starts = graph.new_operator_before("pnnx.SliceIndexes", op->name + "_ncnnstarts", op); - if (!static_ends) op_ends = graph.new_operator_before("pnnx.SliceIndexes", op->name + "_ncnnends", op); - if (!static_steps) op_steps = graph.new_operator_before("pnnx.SliceIndexes", op->name + "_ncnnsteps", op); - if (!static_selects) op_selects = graph.new_operator_before("pnnx.SliceIndexes", op->name + "_ncnnselects", op); + if (!static_starts) op_starts = graph->new_operator_before("pnnx.SliceIndexes", op->name + "_ncnnstarts", op); + if (!static_ends) op_ends = graph->new_operator_before("pnnx.SliceIndexes", op->name + "_ncnnends", op); + if (!static_steps) op_steps = graph->new_operator_before("pnnx.SliceIndexes", op->name + "_ncnnsteps", op); + if (!static_selects) op_selects = graph->new_operator_before("pnnx.SliceIndexes", op->name + "_ncnnselects", op); std::vector starts_indexes; std::vector ends_indexes; @@ -338,11 +338,11 @@ void fuse_slice_indices(Graph& graph) // drop sop and sop output Operand* sop_out = sop->outputs[0]; - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), sop_out)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), sop_out)); delete sop_out; - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), sop)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), sop)); delete sop; } @@ -527,7 +527,7 @@ void fuse_slice_indices(Graph& graph) { op_starts->params["indexes"] = starts_indexes; - Operand* starts_out = graph.new_operand(op->name + "_ncnnstarts_out"); + Operand* starts_out = graph->new_operand(op->name + "_ncnnstarts_out"); starts_out->producer = op_starts; op_starts->outputs.push_back(starts_out); starts_out->consumers.push_back(op); @@ -543,7 +543,7 @@ void fuse_slice_indices(Graph& graph) { op_ends->params["indexes"] = ends_indexes; - Operand* ends_out = graph.new_operand(op->name + "_ncnnends_out"); + Operand* ends_out = graph->new_operand(op->name + "_ncnnends_out"); ends_out->producer = op_ends; op_ends->outputs.push_back(ends_out); ends_out->consumers.push_back(op); @@ -559,7 +559,7 @@ void fuse_slice_indices(Graph& graph) { op_steps->params["indexes"] = steps_indexes; - Operand* steps_out = graph.new_operand(op->name + 
"_ncnnsteps_out"); + Operand* steps_out = graph->new_operand(op->name + "_ncnnsteps_out"); steps_out->producer = op_steps; op_steps->outputs.push_back(steps_out); steps_out->consumers.push_back(op); @@ -575,7 +575,7 @@ void fuse_slice_indices(Graph& graph) { op_selects->params["indexes"] = selects_indexes; - Operand* selects_out = graph.new_operand(op->name + "_ncnnselects_out"); + Operand* selects_out = graph->new_operand(op->name + "_ncnnselects_out"); selects_out->producer = op_selects; op_selects->outputs.push_back(selects_out); selects_out->consumers.push_back(op); diff --git a/tools/pnnx/src/pass_level5/fuse_slice_indices.h b/tools/pnnx/src/pass_level5/fuse_slice_indices.h index ca8dfa32da8..99d028c8260 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_indices.h +++ b/tools/pnnx/src/pass_level5/fuse_slice_indices.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_slice_indices(Graph& graph); +void fuse_slice_indices(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp index 4e767a02d00..7a92d1b8fa7 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp +++ b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.cpp @@ -20,15 +20,15 @@ namespace pnnx { -void fuse_slice_to_tensor_split(Graph& graph) +void fuse_slice_to_tensor_split(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.slice") continue; @@ -99,7 +99,7 @@ void fuse_slice_to_tensor_split(Graph& graph) if (!op2) break; - if (std::find(graph.ops.begin(), graph.ops.end(), op2) < std::find(graph.ops.begin(), graph.ops.end(), cur)) + if (std::find(graph->ops.begin(), graph->ops.end(), op2) < std::find(graph->ops.begin(), graph->ops.end(), cur)) cur = op2; int end2 = op2->params.at("ends").ai[0]; @@ -128,7 +128,7 @@ void fuse_slice_to_tensor_split(Graph& graph) matched = true; // delete all slice ops and replace with tensor_split - Operator* op_tensor_split = graph.new_operator_before("torch.tensor_split", op->name, cur); + Operator* op_tensor_split = graph->new_operator_before("torch.tensor_split", op->name, cur); op_tensor_split->params["dim"] = dim; op_tensor_split->params["indices"] = tensor_split_indices; @@ -148,7 +148,7 @@ void fuse_slice_to_tensor_split(Graph& graph) for (size_t j = 0; j < slice_n_ops.size(); j++) { - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), slice_n_ops[j])); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), slice_n_ops[j])); delete slice_n_ops[j]; } diff --git a/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h index 1c172838bc1..5f6b1e7ac53 100644 --- a/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h +++ b/tools/pnnx/src/pass_level5/fuse_slice_to_tensor_split.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_slice_to_tensor_split(Graph& graph); +void fuse_slice_to_tensor_split(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp index 1a3363b603f..ea81a4feedc 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.cpp @@ -195,7 +195,7 @@ pnnx.Output output 1 0 out } }; -void 
fuse_static_batchnorm(Graph& graph) +void fuse_static_batchnorm(std::shared_ptr graph) { fuse_static_Fbatchnorm_pass_1d a; fuse_static_Fbatchnorm_pass_2d b; diff --git a/tools/pnnx/src/pass_level5/fuse_static_batchnorm.h b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.h index 7ffc7ca2ce8..c6103b8e6a9 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_batchnorm.h +++ b/tools/pnnx/src/pass_level5/fuse_static_batchnorm.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_static_batchnorm(Graph& graph); +void fuse_static_batchnorm(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_conv.cpp b/tools/pnnx/src/pass_level5/fuse_static_conv.cpp index 4dda5006d90..6c29f10d051 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_conv.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_conv.cpp @@ -345,7 +345,7 @@ pnnx.Output output 1 0 out } }; -void fuse_static_conv(Graph& graph) +void fuse_static_conv(std::shared_ptr graph) { fuse_static_Fconv1d_pass_3 a3; fuse_static_Fconv2d_pass_3 a4; diff --git a/tools/pnnx/src/pass_level5/fuse_static_conv.h b/tools/pnnx/src/pass_level5/fuse_static_conv.h index b6bf6c0aeb1..35847545fc2 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_conv.h +++ b/tools/pnnx/src/pass_level5/fuse_static_conv.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_static_conv(Graph& graph); +void fuse_static_conv(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp index 5d6aa66f3c2..8c2524abe77 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.cpp @@ -240,7 +240,7 @@ pnnx.Output output 1 0 out } }; -void fuse_static_convtranspose(Graph& graph) +void fuse_static_convtranspose(std::shared_ptr graph) { fuse_static_Fconvtranspose1d_pass a; fuse_static_Fconvtranspose1d_pass_2 b; diff --git a/tools/pnnx/src/pass_level5/fuse_static_convtranspose.h b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.h index 2474074a150..48d808af61d 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_convtranspose.h +++ b/tools/pnnx/src/pass_level5/fuse_static_convtranspose.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_static_convtranspose(Graph& graph); +void fuse_static_convtranspose(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_embedding.cpp b/tools/pnnx/src/pass_level5/fuse_static_embedding.cpp index f5ad240f29d..5b3e59ded83 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_embedding.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_embedding.cpp @@ -43,7 +43,7 @@ pnnx.Output output 1 0 out } }; -void fuse_static_embedding(Graph& graph) +void fuse_static_embedding(std::shared_ptr graph) { fuse_static_Fembedding_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_static_embedding.h b/tools/pnnx/src/pass_level5/fuse_static_embedding.h index 3e53c86653a..89d76e81056 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_embedding.h +++ b/tools/pnnx/src/pass_level5/fuse_static_embedding.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_static_embedding(Graph& graph); +void fuse_static_embedding(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp index da0d6112bcf..85426679c02 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.cpp 
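Since every pass entry point in this refactor now takes the graph by `std::shared_ptr`, call sites change mechanically: the caller holds the IR in a shared pointer and passes it by value. A minimal sketch of such a driver, assuming the upstream `Graph::load`/`Graph::save` API from `ir.h`; the function name and file paths are placeholders, not code from this tree:

```cpp
#include <memory>
#include <string>

#include "ir.h"
#include "pass_level5/fuse_static_batchnorm.h"
#include "pass_level5/fuse_static_conv.h"

// Hypothetical driver: holding the IR in a shared_ptr lets passes and the
// new sub-model extraction logic share ownership without dangling references.
void run_static_fuse_passes(const std::string& parampath, const std::string& binpath)
{
    std::shared_ptr<pnnx::Graph> graph = std::make_shared<pnnx::Graph>();
    graph->load(parampath, binpath);

    // call syntax is unchanged; only the holder type differs
    pnnx::fuse_static_batchnorm(graph);
    pnnx::fuse_static_conv(graph);

    graph->save("model_opt.pnnx.param", "model_opt.pnnx.bin");
}
```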
@@ -47,7 +47,7 @@ pnnx.Output output 1 0 out } }; -void fuse_static_groupnorm(Graph& graph) +void fuse_static_groupnorm(std::shared_ptr graph) { fuse_static_Fgroupnorm_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_static_groupnorm.h b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.h index 2de65fa307b..60556260162 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_groupnorm.h +++ b/tools/pnnx/src/pass_level5/fuse_static_groupnorm.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_static_groupnorm(Graph& graph); +void fuse_static_groupnorm(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp index 76543b34c5b..a872811ae34 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.cpp @@ -105,7 +105,7 @@ pnnx.Output output 1 0 out } }; -void fuse_static_instancenorm(Graph& graph) +void fuse_static_instancenorm(std::shared_ptr graph) { fuse_static_Finstancenorm_pass_1d a; fuse_static_Finstancenorm_pass_2d b; diff --git a/tools/pnnx/src/pass_level5/fuse_static_instancenorm.h b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.h index df71b0e52a7..1ab4a45a1de 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_instancenorm.h +++ b/tools/pnnx/src/pass_level5/fuse_static_instancenorm.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_static_instancenorm(Graph& graph); +void fuse_static_instancenorm(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp b/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp index 0b1f0dc4179..78677203dc9 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_layernorm.cpp @@ -47,7 +47,7 @@ pnnx.Output output 1 0 out } }; -void fuse_static_layernorm(Graph& graph) +void fuse_static_layernorm(std::shared_ptr graph) { fuse_static_Flayernorm_pass a; int opindex = 0; diff --git a/tools/pnnx/src/pass_level5/fuse_static_layernorm.h b/tools/pnnx/src/pass_level5/fuse_static_layernorm.h index e61f254d2b5..09de1098a17 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_layernorm.h +++ b/tools/pnnx/src/pass_level5/fuse_static_layernorm.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_static_layernorm(Graph& graph); +void fuse_static_layernorm(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_static_linear.cpp b/tools/pnnx/src/pass_level5/fuse_static_linear.cpp index 7396142d461..a21cb7749aa 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_linear.cpp +++ b/tools/pnnx/src/pass_level5/fuse_static_linear.cpp @@ -108,7 +108,7 @@ pnnx.Output output 1 0 out } }; -void fuse_static_linear(Graph& graph) +void fuse_static_linear(std::shared_ptr graph) { fuse_static_Flinear_pass_3 a3; diff --git a/tools/pnnx/src/pass_level5/fuse_static_linear.h b/tools/pnnx/src/pass_level5/fuse_static_linear.h index 8c26f924c16..bf64604698e 100644 --- a/tools/pnnx/src/pass_level5/fuse_static_linear.h +++ b/tools/pnnx/src/pass_level5/fuse_static_linear.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_static_linear(Graph& graph); +void fuse_static_linear(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp b/tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp index 0f1902ab921..60cc64bbc97 100644 --- a/tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp +++ 
b/tools/pnnx/src/pass_level5/normalize_einsum_equation.cpp @@ -32,11 +32,11 @@ static void replaceAll(std::string& str, const std::string& from, const std::str } } -void normalize_einsum_equation(Graph& graph) +void normalize_einsum_equation(std::shared_ptr graph) { - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "torch.einsum") continue; diff --git a/tools/pnnx/src/pass_level5/normalize_einsum_equation.h b/tools/pnnx/src/pass_level5/normalize_einsum_equation.h index 30cda5d7b54..dd4b42e4878 100644 --- a/tools/pnnx/src/pass_level5/normalize_einsum_equation.h +++ b/tools/pnnx/src/pass_level5/normalize_einsum_equation.h @@ -16,6 +16,6 @@ namespace pnnx { -void normalize_einsum_equation(Graph& graph); +void normalize_einsum_equation(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp b/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp index 2fda0242309..dd2ff421cc5 100644 --- a/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp +++ b/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp @@ -18,15 +18,15 @@ namespace pnnx { -void unroll_rnn_op(Graph& graph) +void unroll_rnn_op(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "nn.RNN" && op->type != "nn.LSTM" && op->type != "nn.GRU") continue; @@ -56,7 +56,7 @@ void unroll_rnn_op(Graph& graph) { std::string opname = op->name + "_chunk_in_hidden"; - Operator* op1 = graph.new_operator_before("torch.chunk", opname, op); + Operator* op1 = graph->new_operator_before("torch.chunk", opname, op); op1->params["chunks"] = num_layers; op1->params["dim"] = 0; @@ -67,7 +67,7 @@ void unroll_rnn_op(Graph& graph) for (int j = 0; j < num_layers; j++) { - Operand* r0 = graph.new_operand(op1->name + "_in_hidden_" + std::to_string(j)); + Operand* r0 = graph->new_operand(op1->name + "_in_hidden_" + std::to_string(j)); r0->producer = op1; op1->outputs.push_back(r0); @@ -78,7 +78,7 @@ void unroll_rnn_op(Graph& graph) { std::string opname = op->name + "_chunk_in_cell"; - Operator* op1 = graph.new_operator_before("torch.chunk", opname, op); + Operator* op1 = graph->new_operator_before("torch.chunk", opname, op); op1->params["chunks"] = num_layers; op1->params["dim"] = 0; @@ -89,7 +89,7 @@ void unroll_rnn_op(Graph& graph) for (int j = 0; j < num_layers; j++) { - Operand* r0 = graph.new_operand(op1->name + "_in_cell_" + std::to_string(j)); + Operand* r0 = graph->new_operand(op1->name + "_in_cell_" + std::to_string(j)); r0->producer = op1; op1->outputs.push_back(r0); @@ -103,7 +103,7 @@ void unroll_rnn_op(Graph& graph) { std::string opname = op->name + "_unroll_" + std::to_string(j); - Operator* op1 = graph.new_operator_before(op->type, opname, op); + Operator* op1 = graph->new_operator_before(op->type, opname, op); op1->params = op->params; op1->params["num_layers"] = 1; @@ -148,14 +148,14 @@ void unroll_rnn_op(Graph& graph) } else { - Operand* r0 = graph.new_operand(op1->name + "_out"); + Operand* r0 = graph->new_operand(op1->name + "_out"); r0->producer = op1; op1->outputs.push_back(r0); } if (has_output_hidden) { - Operand* r1 = graph.new_operand(op1->name + "_out_hidden"); + Operand* r1 = graph->new_operand(op1->name + "_out_hidden"); r1->producer = op1; op1->outputs.push_back(r1); @@ -163,7 +163,7 @@ void 
unroll_rnn_op(Graph& graph)
             }
             if (has_output_cell)
             {
-                Operand* r1 = graph.new_operand(op1->name + "_out_cell");
+                Operand* r1 = graph->new_operand(op1->name + "_out_cell");
                 r1->producer = op1;
                 op1->outputs.push_back(r1);
@@ -209,7 +209,7 @@ void unroll_rnn_op(Graph& graph)
         {
             std::string opname = op->name + "_cat_out_hidden";
 
-            Operator* op1 = graph.new_operator_before("torch.cat", opname, op);
+            Operator* op1 = graph->new_operator_before("torch.cat", opname, op);
 
             op1->params["dim"] = 0;
@@ -227,7 +227,7 @@ void unroll_rnn_op(Graph& graph)
         {
             std::string opname = op->name + "_cat_out_cell";
 
-            Operator* op1 = graph.new_operator_before("torch.cat", opname, op);
+            Operator* op1 = graph->new_operator_before("torch.cat", opname, op);
 
             op1->params["dim"] = 0;
@@ -245,7 +245,7 @@ void unroll_rnn_op(Graph& graph)
         op->inputs.clear();
         op->outputs.clear();
 
-        graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op));
+        graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), op));
         delete op;
diff --git a/tools/pnnx/src/pass_level5/unroll_rnn_op.h b/tools/pnnx/src/pass_level5/unroll_rnn_op.h
index a3d57a84f04..abf8637045b 100644
--- a/tools/pnnx/src/pass_level5/unroll_rnn_op.h
+++ b/tools/pnnx/src/pass_level5/unroll_rnn_op.h
@@ -16,6 +16,6 @@
 namespace pnnx {
 
-void unroll_rnn_op(Graph& graph);
+void unroll_rnn_op(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level6.cpp b/tools/pnnx/src/pass_level6.cpp
index afd136b5422..8a680e54df8 100644
--- a/tools/pnnx/src/pass_level6.cpp
+++ b/tools/pnnx/src/pass_level6.cpp
@@ -19,10 +19,11 @@
 #include "pass_level6/trans_Stack2Unsqueeze.h"
 #include "pass_level6/trans_ReshapeAs2Reshape.h"
 #include "pass_level6/trans_TensorTypeAs2TensorTo.h"
-
+#include "pass_level6/fold_Loop.h"
+#include "config.h"
 namespace pnnx {
 
-void pass_level6(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
+void pass_level6(std::shared_ptr<Graph> g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
 {
     eliminate_ListUnpack(g);
     fprintf(stderr, "############# finish eliminate_ListUnpack\n");
@@ -30,10 +31,16 @@
     fprintf(stderr, "############# finish trans_expression2TupleConstruct\n");
     trans_Stack2Unsqueeze(g);
     fprintf(stderr, "############# finish trans_Stack2Unsqueeze\n");
-    trans_ReshapeAs2Reshape(g);
-    fprintf(stderr, "############# finish trans_ReshapeAs2Reshape\n");
+    if(!dynamic_network)
+    {
+        trans_ReshapeAs2Reshape(g);
+        fprintf(stderr, "############# finish trans_ReshapeAs2Reshape\n");
+    }
+
     trans_TensorTypeAs2TensorTo(g);
     fprintf(stderr, "############# finish trans_TensorTypeAs2TensorTo\n");
+    fold_Loop(g);
+    fprintf(stderr, "############# finish fold_Loop\n");
 }
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level6.h b/tools/pnnx/src/pass_level6.h
index 49d1ef8d201..0ce9cf8ffc2 100644
--- a/tools/pnnx/src/pass_level6.h
+++ b/tools/pnnx/src/pass_level6.h
@@ -19,7 +19,7 @@
 namespace pnnx {
 
-void pass_level6(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
+void pass_level6(std::shared_ptr<Graph> g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath);
 
 } // namespace pnnx
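The pass_level6 rework above changes behavior in two ways: trans_ReshapeAs2Reshape is now skipped for dynamic networks, presumably because a `reshape_as` target shape is only known at runtime and rewriting it into a static reshape would bake in one traced shape; and the new fold_Loop pass runs last. `config.h` is not shown in this diff; a plausible sketch of the flag it exposes, not the actual header:

```cpp
// config.h (sketch): a process-wide flag set by the CLI front end when any
// input shape is dynamic; pass_level6 consults it before shape rewrites.
extern bool dynamic_network;
```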
diff --git a/tools/pnnx/src/pass_level6/eliminate_ListUnpack.cpp b/tools/pnnx/src/pass_level6/eliminate_ListUnpack.cpp
index 7a4a889f5f9..c7e2af1f8e4 100644
--- a/tools/pnnx/src/pass_level6/eliminate_ListUnpack.cpp
+++ b/tools/pnnx/src/pass_level6/eliminate_ListUnpack.cpp
@@ -19,15 +19,15 @@
 namespace pnnx {
 
-void eliminate_ListUnpack(Graph& graph)
+void eliminate_ListUnpack(std::shared_ptr<Graph> graph)
 {
     while (1)
     {
         bool matched = false;
 
-        for (size_t i = 0; i < graph.ops.size(); i++)
+        for (size_t i = 0; i < graph->ops.size(); i++)
        {
-            Operator* op = graph.ops[i];
+            Operator* op = graph->ops[i];
 
             if (op->type != "prim::ListUnpack")
                 continue;
@@ -45,13 +45,13 @@ void eliminate_ListUnpack(Graph& graph)
             }
             ListUnpack_input->producer = 0;
             ListUnpack_input->consumers.clear();
-            graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), ListUnpack_input));
+            graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), ListUnpack_input));
             delete ListUnpack_input;
 
             op->inputs.clear();
             op->outputs.clear();
 
-            graph.ops.erase(graph.ops.begin() + i);
+            graph->ops.erase(graph->ops.begin() + i);
             delete op;
 
             break;
diff --git a/tools/pnnx/src/pass_level6/eliminate_ListUnpack.h b/tools/pnnx/src/pass_level6/eliminate_ListUnpack.h
index b9bf2710b0e..5b69ffe2381 100644
--- a/tools/pnnx/src/pass_level6/eliminate_ListUnpack.h
+++ b/tools/pnnx/src/pass_level6/eliminate_ListUnpack.h
@@ -16,6 +16,6 @@
 namespace pnnx {
 
-void eliminate_ListUnpack(Graph& graph);
+void eliminate_ListUnpack(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level6/fold_Loop.cpp b/tools/pnnx/src/pass_level6/fold_Loop.cpp
new file mode 100644
index 00000000000..faa4e411e4f
--- /dev/null
+++ b/tools/pnnx/src/pass_level6/fold_Loop.cpp
@@ -0,0 +1,93 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
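+//
+// fold_Loop: collapse a TorchScript prim::Loop into a single pnnx.Loop op.
+// The two pnnx.Expression operators feeding the loop's trip-count and
+// condition inputs are evaluated into "iter_num" and "condition" parameters
+// (static trip counts only; dynamic counts are still a todo), then the
+// expressions and their operands are unlinked and freed.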
+ +#include "fold_Loop.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void fold_Loop(std::shared_ptr graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph->ops.size(); i++) + { + Operator* op = graph->ops[i]; + + if (op->type != "prim::Loop") + continue; + op->type = "pnnx.Loop"; + // delete prim::Loop + matched = true; + Operand* loop_iterNum_input = op->inputs[0]; + Operand* loop_condition_input = op->inputs[1]; + + op->inputs.erase(op->inputs.begin()); + op->inputs.erase(op->inputs.begin()); + // parse iterNum only used for static + // [todo] dynamic + Operator* loop_iterNum_expression = loop_iterNum_input->producer; + std::string iterNum_expr = loop_iterNum_expression->params["expr"].s; + // check pre_node or not + if(loop_iterNum_expression->inputs.size() == 0) + { + int iter_num = std::stoi(iterNum_expr); + op->params["iter_num"] = iter_num; + } + else{ + Operand* pre_loop_iterNum_input = loop_iterNum_expression->inputs[0]; + op->params["iter_num"] = pre_loop_iterNum_input->shape[0]; + pre_loop_iterNum_input->consumers.erase(std::find(pre_loop_iterNum_input->consumers.begin(), pre_loop_iterNum_input->consumers.end(), loop_iterNum_expression)); + loop_iterNum_expression->inputs.clear(); + } + // delete iterNum expression + loop_iterNum_input->producer = 0; + loop_iterNum_input->consumers.clear(); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), loop_iterNum_input)); + delete loop_iterNum_input; + + loop_iterNum_expression->inputs.clear(); + loop_iterNum_expression->outputs.clear(); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), loop_iterNum_expression)); + delete loop_iterNum_expression; + + // parse condition + Operator* loop_condition_expression = loop_condition_input->producer; + std::string condition_expr = loop_condition_expression->params["expr"].s; + op->params["condition"] = condition_expr; + + // delete condition expression + loop_condition_input->producer = 0; + loop_condition_input->consumers.clear(); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), loop_condition_input)); + delete loop_condition_input; + + loop_condition_expression->inputs.clear(); + loop_condition_expression->outputs.clear(); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), loop_condition_expression)); + delete loop_condition_expression; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level6/fold_Loop.h b/tools/pnnx/src/pass_level6/fold_Loop.h new file mode 100644 index 00000000000..50d6ca83070 --- /dev/null +++ b/tools/pnnx/src/pass_level6/fold_Loop.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "ir.h" + +namespace pnnx { + +void fold_Loop(std::shared_ptr graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level6/trans_ReshapeAs2Reshape.cpp b/tools/pnnx/src/pass_level6/trans_ReshapeAs2Reshape.cpp index 66b5483d3bd..de4fcc4e976 100644 --- a/tools/pnnx/src/pass_level6/trans_ReshapeAs2Reshape.cpp +++ b/tools/pnnx/src/pass_level6/trans_ReshapeAs2Reshape.cpp @@ -19,15 +19,15 @@ namespace pnnx { -void trans_ReshapeAs2Reshape(Graph& graph) +void trans_ReshapeAs2Reshape(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "Tensor.reshape_as") continue; @@ -55,11 +55,11 @@ void trans_ReshapeAs2Reshape(Graph& graph) // } // } - // for(int index = 0; index < graph.operands.size(); index++) + // for(int index = 0; index < graph->operands.size(); index++) // { - // if(graph.operands[index]->name == input1->name) + // if(graph->operands[index]->name == input1->name) // { - // graph.operands.erase(graph.operands.begin() + index); + // graph->operands.erase(graph->operands.begin() + index); // break; // } // } @@ -112,12 +112,12 @@ void trans_ReshapeAs2Reshape(Graph& graph) { delete_op->inputs.clear(); delete_op->outputs.clear(); - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), delete_op)); + graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), delete_op)); delete delete_op; } for(auto delete_operand: delete_operands) { - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), delete_operand)); + graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), delete_operand)); delete delete_operand; } break; diff --git a/tools/pnnx/src/pass_level6/trans_ReshapeAs2Reshape.h b/tools/pnnx/src/pass_level6/trans_ReshapeAs2Reshape.h index 893cf19d588..529a7e93fbf 100644 --- a/tools/pnnx/src/pass_level6/trans_ReshapeAs2Reshape.h +++ b/tools/pnnx/src/pass_level6/trans_ReshapeAs2Reshape.h @@ -17,6 +17,6 @@ namespace pnnx { -void trans_ReshapeAs2Reshape(Graph& graph); +void trans_ReshapeAs2Reshape(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp b/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp index cf74f703969..6858a914aff 100644 --- a/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp +++ b/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp @@ -19,18 +19,19 @@ namespace pnnx { -void trans_Stack2Unsqueeze(Graph& graph) +void trans_Stack2Unsqueeze(std::shared_ptr graph) { while (1) { bool matched = false; - for (size_t i = 0; i < graph.ops.size(); i++) + for (size_t i = 0; i < graph->ops.size(); i++) { - Operator* op = graph.ops[i]; + Operator* op = graph->ops[i]; if (op->type != "torch.stack") continue; + // get input num if( op->inputs.size() == 1) { diff --git a/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h b/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h index 777c00e7bd7..4b204d8976f 100644 --- a/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h +++ b/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h @@ -17,6 +17,6 @@ namespace pnnx { -void trans_Stack2Unsqueeze(Graph& graph); +void trans_Stack2Unsqueeze(std::shared_ptr graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.cpp b/tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.cpp index 0dbb1e61bfe..4a55f3d6e6e 100644 --- 
+++ b/tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.cpp
@@ -19,15 +19,15 @@
 
 namespace pnnx {
 
-void trans_TensorTypeAs2TensorTo(Graph& graph)
+void trans_TensorTypeAs2TensorTo(std::shared_ptr<Graph> graph)
 {
     while (1)
     {
         bool matched = false;
 
-        for (size_t i = 0; i < graph.ops.size(); i++)
+        for (size_t i = 0; i < graph->ops.size(); i++)
        {
-            Operator* op = graph.ops[i];
+            Operator* op = graph->ops[i];
 
             if (op->type != "Tensor.type_as")
                 continue;
@@ -101,12 +101,12 @@ void trans_TensorTypeAs2TensorTo(Graph& graph)
             {
                 delete_op->inputs.clear();
                 delete_op->outputs.clear();
-                graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), delete_op));
+                graph->ops.erase(std::find(graph->ops.begin(), graph->ops.end(), delete_op));
                 delete delete_op;
             }
             for(auto delete_operand: delete_operands)
             {
-                graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), delete_operand));
+                graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), delete_operand));
                 delete delete_operand;
             }
             break;
diff --git a/tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.h b/tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.h
index 245e3bb0c48..9347de05848 100644
--- a/tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.h
+++ b/tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.h
@@ -17,6 +17,6 @@
 
 namespace pnnx {
 
-void trans_TensorTypeAs2TensorTo(Graph& graph);
+void trans_TensorTypeAs2TensorTo(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp b/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp
index 5f60d7c29c4..257594d9bcf 100644
--- a/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp
+++ b/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.cpp
@@ -19,16 +19,16 @@
 
 namespace pnnx {
 
-void trans_expression2TupleConstruct(Graph& graph)
+void trans_expression2TupleConstruct(std::shared_ptr<Graph> graph)
 {
     while (1)
     {
         bool matched = false;
 
-        for (size_t i = 0; i < graph.ops.size(); i++)
+        for (size_t i = 0; i < graph->ops.size(); i++)
         {
-            Operator* op = graph.ops[i];
-
+            Operator* op = graph->ops[i];
+
             if (op->type != "pnnx.Expression")
                 continue;
             // get expr
@@ -67,13 +67,13 @@ void trans_expression2TupleConstruct(Graph& graph)
             }
             input->producer = 0;
             input->consumers.clear();
-            graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), input));
+            graph->operands.erase(std::find(graph->operands.begin(), graph->operands.end(), input));
             delete input;
 
             op->inputs.clear();
             op->outputs.clear();
 
-            graph.ops.erase(graph.ops.begin() + i);
+            graph->ops.erase(graph->ops.begin() + i);
             delete op;
         }
         else
diff --git a/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.h b/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.h
index 957717c77b3..0ee9425d8c7 100644
--- a/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.h
+++ b/tools/pnnx/src/pass_level6/trans_expression2TupleConstruct.h
@@ -16,6 +16,6 @@
 
 namespace pnnx {
 
-void trans_expression2TupleConstruct(Graph& graph);
+void trans_expression2TupleConstruct(std::shared_ptr<Graph> graph);
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/py_proj.cpp b/tools/pnnx/src/py_proj.cpp
index 470ea0669ed..9a38dfe3f3a 100644
--- a/tools/pnnx/src/py_proj.cpp
+++ b/tools/pnnx/src/py_proj.cpp
@@ -5,7 +5,7 @@
 // #include
 #define STRINGIFY(x) #x
 #define MACRO_STRINGIFY(x) STRINGIFY(x)
-#define MYLIBRARY_VERSION "dev.1.0.19.20240614"
+#define MYLIBRARY_VERSION "dev.1.0.21.20240619"
"dev.1.0.21.20240619" using namespace pnnx_graph; using namespace pnnx_ir; namespace py = pybind11; @@ -72,7 +72,10 @@ PYBIND11_MODULE(ptx, m) //add PnnxGraph class py::class_(m, "PnnxGraph") .def(py::init<>()) - .def("getNvpPnnxModel", &PnnxGraph::getNvpPnnxModel, py::arg("pt_path"), py::arg("input_shape"), py::arg("custom_op_path"), py::arg("custom_op_py"), py::arg("start_nodes") = "", py::arg("end_nodes") = "") + .def("getNvpPnnxModel", &PnnxGraph::getNvpPnnxModel, py::arg("pt_path"), \ + py::arg("save_dir"), py::arg("input_shape"), py::arg("custom_op_path"), \ + py::arg("custom_op_py"), py::arg("start_nodes") = "", py::arg("end_nodes") = "",\ + py::arg("extract_model_name") = "model") .def("loadModel", &PnnxGraph::loadModel) .def("saveModel", &PnnxGraph::saveModel) // .def("getOperators", (std::vector(PnnxGraph::*)()) & PnnxGraph::getOperators) diff --git a/tools/pnnx/tools/export.py b/tools/pnnx/tools/export.py index 7fce439f147..ef2c8aaba07 100644 --- a/tools/pnnx/tools/export.py +++ b/tools/pnnx/tools/export.py @@ -9,6 +9,7 @@ import platform try: import torch + import torchvision import torchvision.models as models import torch.nn as nn import torch.nn.functional as F @@ -36,6 +37,7 @@ def input_torch_type_to_str(tensor): + # 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=c64 11=c128 12=c32 if tensor.dtype == torch.float32 or tensor.dtype == torch.float: return "f32" if tensor.dtype == torch.float64 or tensor.dtype == torch.double: @@ -115,7 +117,22 @@ def __init__(self,): def forward(self, x): output = self.unfold(x) return output - + +class NMS(nn.Module): + def __init__(self): + super().__init__() + def forward(self, boxes, scores): + x2 = torchvision.ops.nms(boxes, scores, iou_threshold = 0.2) + return x2 + +class Script1(torch.nn.Module): + def __init__(self,): + super(Script1, self).__init__() + + def forward(self, x, y): + for i in range(int(y)): + x = x + y + return x def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx: bool): if isinstance(input_shape, list): @@ -171,10 +188,12 @@ def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx "stack":stackModel, "oneHot":oneHotModel, "reshape_as": reshape_as_Model, - "unfold":unfold_Model + "unfold":unfold_Model, + "NMS":NMS, + "Script1":Script1 } - model_name = 'unfold' + model_name = 'Script1' if model_name in net_map: net = net_map[model_name]() else: @@ -209,8 +228,13 @@ def export(model_name: str, net: Union[nn.Module, str], input_shape, export_onnx # ---------------------------- # unfold - input_shape = [[1,3,9,9]] - + # input_shape = [[1,3,9,9]] + # input_shape = [[4,4],[4]] + + # Script1 + i1 = torch.ones([5,5]) + i2 = torch.ones(1, dtype=torch.long) + input_shape = [i1,i2] export(model_name, net, input_shape, export_onnx) # import pnnx # pnnx.export