Skip to content

Commit

Permalink
Add new pass trans_Stack2Unsqueeze, When using torch.stack with a sin…
Browse files Browse the repository at this point in the history
…gle input and effectively achieving the same result as torch.unsqueeze
  • Loading branch information
sen.li committed May 11, 2024
1 parent 4585c6e commit 0c29d7c
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 4 deletions.
5 changes: 4 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,7 @@ dev.1.0.4.20240327

dev.1.0.5.20240508
1. Synchronize the main ncnn repository
2. Fix missing approximate parameters of nn.GELU
2. Fix missing approximate parameters of nn.GELU

dev.1.0.6.20240511
1. Add new pass trans_Stack2Unsqueeze, When using torch.stack with a single input and effectively achieving the same result as torch.unsqueeze
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ set(pnnx_pass_level5_SRCS
set(pnnx_pass_level6_SRCS
pass_level6/eliminate_ListUnpack.cpp
pass_level6/trans_expression2TupleConstruct.cpp
pass_level6/trans_Stack2Unsqueeze.cpp
)

set(pnnx_pass_ncnn_SRCS
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3296,7 +3296,7 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str());
if (op->inputs.size() == 1)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());
fprintf(pyfp, "[v_%s]", sanitize_identifier(op->inputs[0]->name).c_str());
}
else
{
Expand Down
3 changes: 2 additions & 1 deletion tools/pnnx/src/pass_level6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

#include "pass_level6/eliminate_ListUnpack.h"
#include "pass_level6/trans_expression2TupleConstruct.h"

#include "pass_level6/trans_Stack2Unsqueeze.h"
namespace pnnx {

void pass_level6(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
{
eliminate_ListUnpack(g);
trans_expression2TupleConstruct(g);
trans_Stack2Unsqueeze(g);
}

} // namespace pnnx
57 changes: 57 additions & 0 deletions tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "trans_Stack2Unsqueeze.h"

#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void trans_Stack2Unsqueeze(Graph& graph)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "torch.stack")
continue;
// get input num
if( op->inputs.size() == 1)
{
op->type = "torch.unsqueeze";
std::string str = op->name;
std::string from = "torch.stack";
std::string to = "torch.unsqueeze";

// to find sub str
size_t start_pos = str.find(from);
if(start_pos != std::string::npos) {
// replace sub str
str.replace(start_pos, from.length(), to);
}
op->name = str;
}
}

if (!matched)
break;
}
}

} // namespace pnnx
22 changes: 22 additions & 0 deletions tools/pnnx/src/pass_level6/trans_Stack2Unsqueeze.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void trans_Stack2Unsqueeze(Graph& graph);

} // namespace pnnx
2 changes: 1 addition & 1 deletion tools/pnnx/src/py_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// #include <torch/extension.h>
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
#define MYLIBRARY_VERSION "dev.1.0.5.20240508"
#define MYLIBRARY_VERSION "dev.1.0.6.20240511"
using namespace pnnx_graph;
using namespace pnnx_ir;
namespace py = pybind11;
Expand Down

0 comments on commit 0c29d7c

Please sign in to comment.