Skip to content

Commit

Permalink
1. Add torch,logical_not and torch.nonzero 2. Add fold constants sub …
Browse files Browse the repository at this point in the history
…graph pass in pass level5 3. Add trans unbind2squeeze pass in pass level6 4. Fix tensor.fil 5. Fix torch.unbind
  • Loading branch information
sen.li committed Jul 23, 2024
1 parent f93c18f commit 1c24d7a
Show file tree
Hide file tree
Showing 15 changed files with 513 additions and 20 deletions.
9 changes: 8 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,11 @@ dev.1.0.24.20240712
1. Fix the bug of extract_sub_graph

dev.1.0.25.20240715
1. Support static qunantize for torch.fx mode
1. Support static qunantize for torch.fx mode

dev.1.0.26.20240723
1. Add torch,logical_not and torch.nonzero
2. Add fold constants sub graph pass in pass level5
3. Add trans unbind2squeeze pass in pass level6
4. Fix tensor.fill
5. Fix torch.unbind
4 changes: 4 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_index_select.cpp
pass_level2/torch_le.cpp
pass_level2/torch_lgamma.cpp
pass_level2/torch_logical_not.cpp
pass_level2/torch_logsumexp.cpp
pass_level2/torch_lt.cpp
pass_level2/torch_masked_select.cpp
Expand All @@ -248,6 +249,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_mv.cpp
pass_level2/torch_narrow.cpp
pass_level2/torch_ne.cpp
pass_level2/torch_nonzero.cpp
pass_level2/torch_norm.cpp
pass_level2/torch_normal.cpp
pass_level2/torch_ones.cpp
Expand Down Expand Up @@ -376,6 +378,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_static_linear.cpp
pass_level5/normalize_einsum_equation.cpp
pass_level5/unroll_rnn_op.cpp
pass_level5/fold_constants_sub_graph.cpp
)

# add by senli 20240321
Expand All @@ -385,6 +388,7 @@ set(pnnx_pass_level6_SRCS
pass_level6/trans_Stack2Unsqueeze.cpp
pass_level6/trans_ReshapeAs2Reshape.cpp
pass_level6/trans_TensorTypeAs2TensorTo.cpp
pass_level6/trans_Unbind2Squeeze.cpp
)

set(pnnx_pass_sub_model_SRCS
Expand Down
50 changes: 43 additions & 7 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@
#include <list>
namespace pnnx {

static std::vector<std::string> options = {"main", "replace", "delete"};
static void get_op_name_label(std::string& src_str, std::string& name, std::string& label)
{

size_t pos = src_str.find_last_of('_');

if (pos != std::string::npos) {
name = src_str.substr(0, pos);
label = src_str.substr(pos + 1);
auto it = std::find(options.begin(), options.end(), label);
if (it == options.end())
{
name = src_str;
label = "";
}

} else {
name = src_str;
label = "";
}
}


static bool type_is_integer(int type)
{
if (type == 1) return false;
Expand Down Expand Up @@ -735,14 +758,16 @@ int Graph::load(const std::string& parampath, const std::string& binpath)
std::istringstream iss(line);

std::string type;
std::string name;
std::string new_op_name;
int input_count = 0;
int output_count = 0;

iss >> type >> name >> input_count >> output_count;

iss >> type >> new_op_name >> input_count >> output_count;
std::string name;
std::string label;
get_op_name_label(new_op_name, name, label);
Operator* op = new_operator(type, name);

op->label = label;
for (int j = 0; j < input_count; j++)
{
std::string operand_name;
Expand Down Expand Up @@ -825,8 +850,9 @@ int Graph::save(const std::string& parampath, const std::string& binpath)

for (const Operator* op : ops)
{
fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());

std::string new_op_name = op->name + "_" + op->label;
fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), new_op_name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());

for (const Operand* oprand : op->inputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
Expand Down Expand Up @@ -3919,6 +3945,10 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
{
fprintf(pyfp, "torch.tensor(False)");
}
else if(op->type == "Tensor.fill")
{
fprintf(pyfp, "True");
}
else
{
fprintf(pyfp, "None");
Expand Down Expand Up @@ -4048,9 +4078,15 @@ int Graph::python_infer(const std::string& pypath, const std::string& binpath,
}
fprintf(pyfp, ")");
}

}

fprintf(pyfp, ")\n");
fprintf(pyfp, ")");
if(op->outputs.size() == 1 && op->type == "torch.unbind")
{
fprintf(pyfp, "[0]");
}
fprintf(pyfp, "\n");
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ class Operator
// keep std::string typed member the last for cross cxxabi compatibility
std::string type;
std::string name;

std::string label = ""; // main delete replace
std::vector<std::string> inputnames;
std::map<std::string, Parameter> params;
std::map<std::string, Attribute> attrs;
Expand Down
38 changes: 31 additions & 7 deletions tools/pnnx/src/parse/pnnx_ir_parse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// specific language governing permissions and limitations under the License.

#include "pnnx_ir_parse.h"

#include <limits.h>
#include <stdint.h>
#include <string.h>
Expand All @@ -31,6 +30,27 @@
using namespace pnnx;
namespace pnnx_ir {

static std::vector<std::string> options = {"main", "replace", "delete"};
static void get_op_name_label(std::string& src_str, std::string& name, std::string& label)
{

size_t pos = src_str.find_last_of('_');

if (pos != std::string::npos) {
name = src_str.substr(0, pos);
label = src_str.substr(pos + 1);
auto it = std::find(options.begin(), options.end(), label);
if (it == options.end())
{
name = src_str;
label = "";
}

} else {
name = src_str;
label = "";
}
}
static size_t countSubstring(const std::string& str, const std::string& substr) {
size_t count = 0;
size_t pos = 0;
Expand Down Expand Up @@ -765,13 +785,16 @@ int Graph::load(const std::string& parampath, const std::string& binpath)
std::istringstream iss(line);

std::string type;
std::string name;
std::string new_op_name;
int input_count = 0;
int output_count = 0;

iss >> type >> name >> input_count >> output_count;

iss >> type >> new_op_name >> input_count >> output_count;
std::string name;
std::string label;
get_op_name_label(new_op_name, name, label);
Operator* op = new_operator(type, name);
op->label = label;

for (int j = 0; j < input_count; j++)
{
Expand Down Expand Up @@ -855,8 +878,8 @@ int Graph::save(const std::string& parampath, const std::string& binpath)

for (const Operator* op : ops)
{
fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());

std::string new_op_name = op->name + "_" + op->label;
fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), new_op_name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());
for (const Operand* oprand : op->inputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
Expand Down Expand Up @@ -1041,7 +1064,8 @@ int Graph::save_param(const std::string& parampath, const std::vector<Operator>&

for (const Operator op : input_operators)
{
fprintf(paramfp, "%-24s %-24s %d %d", op.type.c_str(), op.name.c_str(), (int)op.inputs.size(), (int)op.outputs.size());
std::string new_op_name = op.name + "_" + op.label;
fprintf(paramfp, "%-24s %-24s %d %d", op.type.c_str(), new_op_name.c_str(), (int)op.inputs.size(), (int)op.outputs.size());

for (const Operand* oprand : op.inputs)
{
Expand Down
3 changes: 2 additions & 1 deletion tools/pnnx/src/parse/pnnx_ir_parse.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <ir.h>
namespace py = pybind11;
#if BUILD_PNNX
namespace torch {
Expand Down Expand Up @@ -197,7 +198,7 @@ class Operator
// keep std::string typed member the last for cross cxxabi compatibility
std::string type;
std::string name;

std::string label = ""; // main delete replace
std::vector<std::string> inputnames;
std::map<std::string, Parameter> params;
std::map<std::string, Attribute> attrs;
Expand Down
40 changes: 40 additions & 0 deletions tools/pnnx/src/pass_level2/torch_logical_not.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_logical_not : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input_0 0 1 input
aten::logical_not op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.logical_not";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_logical_not, 20)

} // namespace pnnx
40 changes: 40 additions & 0 deletions tools/pnnx/src/pass_level2/torch_nonzero.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_nonzero : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input_0 0 1 input
aten::nonzero op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.nonzero";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_nonzero, 20)

} // namespace pnnx
3 changes: 2 additions & 1 deletion tools/pnnx/src/pass_level5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
#include "pass_level4/canonicalize.h"
#include "pass_level3/fuse_index_expression.h"
#include "pass_level5/fuse_pixel_unshuffle.h"

#include "pass_level5/fold_constants_sub_graph.h"
namespace pnnx {

void pass_level5(std::shared_ptr<pnnx::Graph> g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
Expand Down Expand Up @@ -145,6 +145,7 @@ void pass_level5(std::shared_ptr<pnnx::Graph> g, const std::set<std::string>& fo

dead_code_elimination(g);

fold_constants_sub_graph(g);
canonicalize(g);
}

Expand Down
Loading

0 comments on commit 1c24d7a

Please sign in to comment.