From 0c29d7cb3cad65fe7561dad59296c2d5413061aa Mon Sep 17 00:00:00 2001 From: "sen.li" Date: Sat, 11 May 2024 11:48:10 +0800 Subject: [PATCH] Add new pass trans_Stack2Unsqueeze, When using torch.stack with a single input and effectively achieving the same result as torch.unsqueeze --- tools/pnnx/Releasenotes | 5 +- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/ir.cpp | 2 +- tools/pnnx/src/pass_level6.cpp | 3 +- .../src/pass_level6/trans_Stack2Unsqueeze.cpp | 57 +++++++++++++++++++ .../src/pass_level6/trans_Stack2Unsqueeze.h | 22 +++++++ tools/pnnx/src/py_proj.cpp | 2 +- 7 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp create mode 100644 tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h diff --git a/tools/pnnx/Releasenotes b/tools/pnnx/Releasenotes index 85dfe09c552..2265078456a 100644 --- a/tools/pnnx/Releasenotes +++ b/tools/pnnx/Releasenotes @@ -20,4 +20,7 @@ dev.1.0.4.20240327 dev.1.0.5.20240508 1. Synchronize the main ncnn repository -2. Fix missing approximate parameters of nn.GELU \ No newline at end of file +2. Fix missing approximate parameters of nn.GELU + +dev.1.0.6.20240511 +1. Add new pass trans_Stack2Unsqueeze, When using torch.stack with a single input and effectively achieving the same result as torch.unsqueeze \ No newline at end of file diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 9d5d80546e3..8bde3c0f8bd 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -380,6 +380,7 @@ set(pnnx_pass_level5_SRCS set(pnnx_pass_level6_SRCS pass_level6/eliminate_ListUnpack.cpp pass_level6/trans_expression2TupleConstruct.cpp + pass_level6/trans_Stack2Unsqueeze.cpp ) set(pnnx_pass_ncnn_SRCS diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 860ac2208a7..c90f0d9b685 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -3296,7 +3296,7 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath, fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str()); if (op->inputs.size() == 1) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + fprintf(pyfp, "[v_%s]", sanitize_identifier(op->inputs[0]->name).c_str()); } else { diff --git a/tools/pnnx/src/pass_level6.cpp b/tools/pnnx/src/pass_level6.cpp index 734db7e2e1f..f4730a824cc 100644 --- a/tools/pnnx/src/pass_level6.cpp +++ b/tools/pnnx/src/pass_level6.cpp @@ -16,13 +16,14 @@ #include "pass_level6/eliminate_ListUnpack.h" #include "pass_level6/trans_expression2TupleConstruct.h" - +#include "pass_level6/trans_Stack2Unsqueeze.h" namespace pnnx { void pass_level6(Graph& g, const std::set& foldable_constants, const std::string& foldable_constants_zippath) { eliminate_ListUnpack(g); trans_expression2TupleConstruct(g); + trans_Stack2Unsqueeze(g); } } // namespace pnnx diff --git a/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp b/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp new file mode 100644 index 00000000000..1ee9233d0c9 --- /dev/null +++ b/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp @@ -0,0 +1,57 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "trans_Stack2Unsqueeze.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void trans_Stack2Unsqueeze(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "torch.stack") + continue; + // get input num + if( op->inputs.size() == 1) + { + op->type = "torch.unsqueeze"; + std::string str = op->name; + std::string from = "torch.stack"; + std::string to = "torch.unsqueeze"; + + // to find sub str + size_t start_pos = str.find(from); + if(start_pos != std::string::npos) { + // replace sub str + str.replace(start_pos, from.length(), to); + } + op->name = str; + } + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h b/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h new file mode 100644 index 00000000000..777c00e7bd7 --- /dev/null +++ b/tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h @@ -0,0 +1,22 @@ + +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void trans_Stack2Unsqueeze(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/py_proj.cpp b/tools/pnnx/src/py_proj.cpp index 42599381fe7..98f9d75635b 100644 --- a/tools/pnnx/src/py_proj.cpp +++ b/tools/pnnx/src/py_proj.cpp @@ -5,7 +5,7 @@ // #include #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) -#define MYLIBRARY_VERSION "dev.1.0.5.20240508" +#define MYLIBRARY_VERSION "dev.1.0.6.20240511" using namespace pnnx_graph; using namespace pnnx_ir; namespace py = pybind11;