From f8560112975a80dcdac84a7d0bd6284467715a49 Mon Sep 17 00:00:00 2001
From: nihui
Date: Tue, 15 Oct 2024 19:28:53 +0800
Subject: [PATCH] pnnx drop onnx weight-like graph input (#5736)

---
 tools/pnnx/src/CMakeLists.txt                 |  1 +
 tools/pnnx/src/load_onnx.cpp                  |  3 +
 .../pass_onnx/eliminate_initializer_input.cpp | 82 +++++++++++++++++++
 .../pass_onnx/eliminate_initializer_input.h   | 25 ++++++
 4 files changed, 111 insertions(+)
 create mode 100644 tools/pnnx/src/pass_onnx/eliminate_initializer_input.cpp
 create mode 100644 tools/pnnx/src/pass_onnx/eliminate_initializer_input.h

diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt
index 7743a8ae453..2281875dbd4 100644
--- a/tools/pnnx/src/CMakeLists.txt
+++ b/tools/pnnx/src/CMakeLists.txt
@@ -662,6 +662,7 @@ if(onnxruntime_FOUND)
     set(pnnx_pass_onnx_SRCS
         pass_onnx/canonicalize.cpp
         pass_onnx/dead_code_elimination.cpp
+        pass_onnx/eliminate_initializer_input.cpp
         pass_onnx/eliminate_noop.cpp
         pass_onnx/fold_constants.cpp
         pass_onnx/inline_containers.cpp
diff --git a/tools/pnnx/src/load_onnx.cpp b/tools/pnnx/src/load_onnx.cpp
index 9adf2b47088..e39c2029659 100644
--- a/tools/pnnx/src/load_onnx.cpp
+++ b/tools/pnnx/src/load_onnx.cpp
@@ -29,6 +29,7 @@
 
 #include "pass_onnx/canonicalize.h"
 #include "pass_onnx/dead_code_elimination.h"
+#include "pass_onnx/eliminate_initializer_input.h"
 #include "pass_onnx/eliminate_noop.h"
 #include "pass_onnx/fold_constants.h"
 #include "pass_onnx/inline_containers.h"
@@ -531,6 +532,8 @@ int load_onnx(const std::string& onnxpath, Graph& pnnx_graph,
         return -1;
     }
 
+    onnx2pnnx::eliminate_initializer_input(model);
+
     // input shape sanity check
     if (!check_input_shape(model, input_shapes, input_types))
     {
diff --git a/tools/pnnx/src/pass_onnx/eliminate_initializer_input.cpp b/tools/pnnx/src/pass_onnx/eliminate_initializer_input.cpp
new file mode 100644
index 00000000000..be447bd26da
--- /dev/null
+++ b/tools/pnnx/src/pass_onnx/eliminate_initializer_input.cpp
@@ -0,0 +1,82 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "eliminate_initializer_input.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace pnnx {
+
+namespace onnx2pnnx {
+
+// Drop every graph input whose name matches one of the graph initializers.
+//
+// Only graph.input() is edited: the initializers stay in the model, and the
+// inputs that are kept retain their relative order.
+void eliminate_initializer_input(onnx::ModelProto& model)
+{
+    // collect the names of all initializers
+    std::unordered_set<std::string> initializers;
+    {
+        const onnx::GraphProto& graph = model.graph();
+        for (int i = 0; i < graph.initializer_size(); i++)
+        {
+            initializers.insert(graph.initializer(i).name());
+        }
+    }
+
+    // collect the indexes of graph inputs named like an initializer, in ascending order
+    std::vector<int> initializer_input_indexes;
+    {
+        const onnx::GraphProto& graph = model.graph();
+        for (int i = 0; i < graph.input_size(); i++)
+        {
+            const std::string& input_name = graph.input(i).name();
+            if (initializers.find(input_name) == initializers.end())
+                continue;
+
+            initializer_input_indexes.push_back(i);
+        }
+    }
+
+    // eliminate initializer graph input
+    {
+        onnx::GraphProto* graph = model.mutable_graph();
+
+        // each removal shifts every later input down by one slot, so consume the
+        // collected indexes from highest to lowest to keep the pending ones valid
+        for (size_t i = initializer_input_indexes.size(); i > 0; i--)
+        {
+            const int initializer_input_index = initializer_input_indexes[i - 1];
+
+            // bubble the input to the end with adjacent swaps
+            // ..... iii .......
+            const int graph_input_size = graph->input_size();
+            for (int j = initializer_input_index; j < graph_input_size - 1; j++)
+            {
+                graph->mutable_input()->SwapElements(j, j + 1);
+            }
+
+            // now it is the last element and can be dropped
+            // ..... ....... iii
+            graph->mutable_input()->RemoveLast();
+        }
+    }
+}
+
+} // namespace onnx2pnnx
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_onnx/eliminate_initializer_input.h b/tools/pnnx/src/pass_onnx/eliminate_initializer_input.h
new file mode 100644
index 00000000000..f82b71cd187
--- /dev/null
+++ b/tools/pnnx/src/pass_onnx/eliminate_initializer_input.h
@@ -0,0 +1,25 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "onnx-ml.pb.h"
+
+namespace pnnx {
+
+namespace onnx2pnnx {
+
+void eliminate_initializer_input(onnx::ModelProto& model);
+
+} // namespace onnx2pnnx
+
+} // namespace pnnx