From 669ee2f2ff520970d5260b9a5c4cbbc2b67e20a4 Mon Sep 17 00:00:00 2001
From: nihui
Date: Mon, 24 Jul 2023 17:26:57 +0800
Subject: [PATCH] pnnx update (#4870)

Tensor.fill
Tensor.index_put
Tensor.to
Tensor.type_as
torch.topk
fmod
call Tensor member functions with inputnames
static shape_as_tensor
nn.Linear dynamic bias
eliminate noop type_as
convert two-dim nn.Linear to ncnn gemm
convert torch.stack to ncnn concat+reshape
ignore torch einsum path input
---
 tools/pnnx/src/CMakeLists.txt                 |   7 +
 tools/pnnx/src/ir.cpp                         |  39 +++-
 .../pnnx/src/pass_level0/shape_inference.cpp  |   3 +-
 tools/pnnx/src/pass_level1/nn_Linear.cpp      |   4 +-
 tools/pnnx/src/pass_level2/Tensor_fill.cpp    |  41 ++++
 .../pnnx/src/pass_level2/Tensor_index_put.cpp |  43 ++++
 .../src/pass_level2/Tensor_masked_fill.cpp    |   2 +-
 tools/pnnx/src/pass_level2/Tensor_to.cpp      |  89 ++++++++
 tools/pnnx/src/pass_level2/Tensor_type_as.cpp |  41 ++++
 tools/pnnx/src/pass_level2/torch_einsum.cpp   |   9 +-
 tools/pnnx/src/pass_level2/torch_topk.cpp     |  44 ++++
 .../pnnx/src/pass_level3/fuse_expression.cpp  |  63 +++++-
 tools/pnnx/src/pass_level5.cpp                |   2 +
 .../src/pass_level5/eliminate_type_as.cpp     |  84 ++++++++
 .../pnnx/src/pass_level5/eliminate_type_as.h  |  21 ++
 .../pnnx/src/pass_level5/eval_expression.cpp  |   6 +
 .../src/pass_level5/fuse_select_to_unbind.cpp |   6 +
 tools/pnnx/src/pass_ncnn.cpp                  |   2 +
 .../src/pass_ncnn/convert_torch_stack.cpp     |  91 +++++++++
 .../pnnx/src/pass_ncnn/convert_torch_stack.h  |  25 +++
 .../pnnx/src/pass_ncnn/expand_expression.cpp  |   8 +-
 .../src/pass_ncnn/insert_reshape_linear.cpp   |  21 +-
 tools/pnnx/src/pass_ncnn/nn_Linear.cpp        | 192 ++++++++++++++++++
 tools/pnnx/tests/CMakeLists.txt               |   6 +
 tools/pnnx/tests/ncnn/CMakeLists.txt          |   1 +
 tools/pnnx/tests/ncnn/test_torch_stack.py     |  60 ++++++
 tools/pnnx/tests/test_Tensor_fill.py          |  57 ++++++
 tools/pnnx/tests/test_Tensor_index_put.py     |  63 ++++++
 tools/pnnx/tests/test_Tensor_to.py            |  63 ++++++
 tools/pnnx/tests/test_Tensor_type_as.py       |  65 ++++++
 tools/pnnx/tests/test_pnnx_expression.py      |  75 +++++++
 tools/pnnx/tests/test_torch_einsum.py         |   5 +-
 tools/pnnx/tests/test_torch_topk.py           |  61 ++++++
 33 files changed, 1278 insertions(+), 21 deletions(-)
 create mode 100644 tools/pnnx/src/pass_level2/Tensor_fill.cpp
 create mode 100644 tools/pnnx/src/pass_level2/Tensor_index_put.cpp
 create mode 100644 tools/pnnx/src/pass_level2/Tensor_to.cpp
 create mode 100644 tools/pnnx/src/pass_level2/Tensor_type_as.cpp
 create mode 100644 tools/pnnx/src/pass_level2/torch_topk.cpp
 create mode 100644 tools/pnnx/src/pass_level5/eliminate_type_as.cpp
 create mode 100644 tools/pnnx/src/pass_level5/eliminate_type_as.h
 create mode 100644 tools/pnnx/src/pass_ncnn/convert_torch_stack.cpp
 create mode 100644 tools/pnnx/src/pass_ncnn/convert_torch_stack.h
 create mode 100644 tools/pnnx/tests/ncnn/test_torch_stack.py
 create mode 100644 tools/pnnx/tests/test_Tensor_fill.py
 create mode 100644 tools/pnnx/tests/test_Tensor_index_put.py
 create mode 100644 tools/pnnx/tests/test_Tensor_to.py
 create mode 100644 tools/pnnx/tests/test_Tensor_type_as.py
 create mode 100644 tools/pnnx/tests/test_pnnx_expression.py
 create mode 100644 tools/pnnx/tests/test_torch_topk.py

diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt
index 0b20339a190..c2bc0306f1e 100644
--- a/tools/pnnx/src/CMakeLists.txt
+++ b/tools/pnnx/src/CMakeLists.txt
@@ -180,7 +180,9 @@ set(pnnx_pass_level2_SRCS
     pass_level2/Tensor_copy.cpp
     pass_level2/Tensor_expand.cpp
     pass_level2/Tensor_expand_as.cpp
+    pass_level2/Tensor_fill.cpp
     pass_level2/Tensor_index.cpp
+    pass_level2/Tensor_index_put.cpp
     pass_level2/Tensor_masked_fill.cpp
     pass_level2/Tensor_new_empty.cpp
     pass_level2/Tensor_new_ones.cpp
@@ -189,6 +191,8 @@ set(pnnx_pass_level2_SRCS
     pass_level2/Tensor_reshape.cpp
     pass_level2/Tensor_select.cpp
     pass_level2/Tensor_slice.cpp
+    pass_level2/Tensor_to.cpp
+    pass_level2/Tensor_type_as.cpp
     pass_level2/Tensor_view.cpp
     pass_level2/torch_addmm.cpp
     pass_level2/torch_amax.cpp
@@ -252,6 +256,7 @@ set(pnnx_pass_level2_SRCS
     pass_level2/torch_sum.cpp
     pass_level2/torch_permute.cpp
     pass_level2/torch_tensor_split.cpp
+    pass_level2/torch_topk.cpp
     pass_level2/torch_transpose.cpp
     pass_level2/torch_unbind.cpp
     pass_level2/torch_unsqueeze.cpp
@@ -320,6 +325,7 @@ set(pnnx_pass_level5_SRCS
     pass_level5/eliminate_noop_slice.cpp
     pass_level5/eliminate_noop_view_reshape.cpp
     pass_level5/eliminate_reshape_shape_expression.cpp
+    pass_level5/eliminate_type_as.cpp
     pass_level5/eval_expression.cpp
     pass_level5/fold_constants.cpp
     pass_level5/fuse_adjacent_reshape.cpp
@@ -361,6 +367,7 @@ set(pnnx_pass_ncnn_SRCS
     pass_ncnn/convert_torch_chunk.cpp
     pass_ncnn/convert_torch_einsum.cpp
     pass_ncnn/convert_torch_split.cpp
+    pass_ncnn/convert_torch_stack.cpp
     pass_ncnn/convert_torch_tensor_split.cpp
     pass_ncnn/convert_torch_unbind.cpp
     pass_ncnn/convert_Tensor_select.cpp
diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp
index 25a24a302f3..554e08e5786 100644
--- a/tools/pnnx/src/ir.cpp
+++ b/tools/pnnx/src/ir.cpp
@@ -1297,10 +1297,12 @@ static std::string expand_expression(const Operator* op)
         exprstack.push(r);
     }
     else if (t == "atan2"
+             || t == "fmod"
              || t == "pow")
     {
         std::string binaryop;
         if (t == "atan2") binaryop = "torch.atan2";
+        if (t == "fmod") binaryop = "torch.fmod";
         if (t == "pow") binaryop = "torch.pow";

         std::string a = exprstack.top();
@@ -1311,7 +1313,7 @@ static std::string expand_expression(const Operator* op)
         std::string r = binaryop + "(" + a + ", " + b + ")";
         exprstack.push(r);
     }
-    else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
+    else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "remainder" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
     {
         std::string binaryop;
         if (t == "add") binaryop = "+";
@@ -1319,6 +1321,7 @@ static std::string expand_expression(const Operator* op)
         if (t == "mul") binaryop = "*";
         if (t == "div") binaryop = "/";
         if (t == "floor_divide") binaryop = "//";
+        if (t == "remainder") binaryop = "%";
         if (t == "and") binaryop = "&";
         if (t == "or") binaryop = "|";
         if (t == "xor") binaryop = "^";
@@ -2152,11 +2155,39 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)

         if (op->type.substr(0, 7) == "Tensor.")
         {
-            fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
+            if (op->type == "Tensor.fill")
+            {
+                fprintf(pyfp, " = v_%s.fill_(", sanitize_identifier(op->inputs[0]->name).c_str());
+            }
+            else
+            {
+                fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
+            }
+
+            if (op->inputnames.size() == op->inputs.size())
+            {
+                for (size_t i = 1; i < op->inputs.size(); i++)
+                {
+                    if (!op->inputnames[i].empty())
+                        continue;

-            for (size_t i = 1; i < op->inputs.size(); i++)
+                    fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
+                }
+
+                for (size_t i = 1; i < op->inputs.size(); i++)
+                {
+                    if (op->inputnames[i].empty())
+                        continue;
+
fprintf(pyfp, "%s=v_%s, ", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str()); + } + } + else { - fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); + for (size_t i = 1; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); + } } } else diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index 431fccb07fb..c8319d130e0 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -39,7 +39,8 @@ static bool value_link_input(const torch::jit::Value* v, const std::vectorparams["in_features"] = weight.size(1); op->params["out_features"] = weight.size(0); - op->params["bias"] = mod.hasattr("bias"); + op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor(); op->attrs["weight"] = weight; - if (mod.hasattr("bias")) + if (mod.hasattr("bias") && mod.attr("bias").isTensor()) { op->attrs["bias"] = mod.attr("bias").toTensor(); } diff --git a/tools/pnnx/src/pass_level2/Tensor_fill.cpp b/tools/pnnx/src/pass_level2/Tensor_fill.cpp new file mode 100644 index 00000000000..da354487faa --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_fill.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_fill : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 value +aten::fill op_0 2 1 input value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.fill"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_fill, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_index_put.cpp b/tools/pnnx/src/pass_level2/Tensor_index_put.cpp new file mode 100644 index 00000000000..cabda8ae27a --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_index_put.cpp @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+class Tensor_index_put : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+6 5
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 indices
+pnnx.Input input_2 0 1 values
+prim::Constant op_0 0 1 accumulate value=%accumulate
+aten::index_put op_1 4 1 input indices values accumulate out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "Tensor.index_put";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_index_put, 20)
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/Tensor_masked_fill.cpp b/tools/pnnx/src/pass_level2/Tensor_masked_fill.cpp
index ae23180f473..be20d57d002 100644
--- a/tools/pnnx/src/pass_level2/Tensor_masked_fill.cpp
+++ b/tools/pnnx/src/pass_level2/Tensor_masked_fill.cpp
@@ -26,7 +26,7 @@ class Tensor_masked_fill : public GraphRewriterPass
 pnnx.Input input_0 0 1 input
 pnnx.Input input_1 0 1 mask
 pnnx.Input input_2 0 1 value
-aten::masked_fill op_1 3 1 input mask value out
+aten::masked_fill op_0 3 1 input mask value out
 pnnx.Output output 1 0 out
 )PNNXIR";
     }
diff --git a/tools/pnnx/src/pass_level2/Tensor_to.cpp b/tools/pnnx/src/pass_level2/Tensor_to.cpp
new file mode 100644
index 00000000000..52a7047105b
--- /dev/null
+++ b/tools/pnnx/src/pass_level2/Tensor_to.cpp
@@ -0,0 +1,89 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+class Tensor_to : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+7 6
+pnnx.Input input_0 0 1 input
+prim::Constant op_0 0 1 dtype value=%dtype
+prim::Constant op_1 0 1 non_blocking value=*
+prim::Constant op_2 0 1 copy value=%copy
+prim::Constant op_3 0 1 memory_format value=%memory_format
+aten::to op_4 5 1 input dtype non_blocking copy memory_format out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "Tensor.to";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+
+        op->params["copy"] = captured_params.at("copy");
+
+        if (captured_params.at("memory_format").i == 0)
+            op->params["memory_format"] = "torch.contiguous_format";
+        if (captured_params.at("memory_format").i == 1)
+            op->params["memory_format"] = "torch.preserve_format";
+        if (captured_params.at("memory_format").i == 2)
+            op->params["memory_format"] = "torch.channels_last";
+    }
+};
+
+class Tensor_to_1 : public Tensor_to
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 7
+pnnx.Input input_0 0 1 input
+prim::Constant op_0 0 1 device value=*
+prim::Constant op_1 0 1 dtype value=%dtype
+prim::Constant op_2 0 1 non_blocking value=*
+prim::Constant op_3 0 1 copy value=%copy
+prim::Constant op_4 0 1 memory_format value=%memory_format
+aten::to op_5 6 1 input device dtype non_blocking copy memory_format out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20)
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20)
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/Tensor_type_as.cpp b/tools/pnnx/src/pass_level2/Tensor_type_as.cpp
new file mode 100644
index 00000000000..e4607793d23
--- /dev/null
+++ b/tools/pnnx/src/pass_level2/Tensor_type_as.cpp
@@ -0,0 +1,41 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+class Tensor_type_as : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 other
+aten::type_as op_0 2 1 input other out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "Tensor.type_as";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_type_as, 20)
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_einsum.cpp b/tools/pnnx/src/pass_level2/torch_einsum.cpp
index f6b24757e50..86feb071dd1 100644
--- a/tools/pnnx/src/pass_level2/torch_einsum.cpp
+++ b/tools/pnnx/src/pass_level2/torch_einsum.cpp
@@ -47,7 +47,7 @@ class torch_einsum_1 : public GraphRewriterPass
 5 4
 pnnx.Input input_0 0 1 equation
 pnnx.Input input_1 0 1 operands
-prim::Constant op_0 0 1 path value=None
+pnnx.Input input_2 0 1 path
 aten::einsum op_1 3 1 equation operands path out
 pnnx.Output output 1 0 out
 )PNNXIR";
@@ -57,6 +57,13 @@ pnnx.Output output 1 0 out
     {
         return "torch.einsum";
     }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
+    {
+        // drop path input
+        op->inputs[2]->remove_consumer(op);
+        op->inputs.resize(2);
+    }
 };

 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_einsum_1, 20)
diff --git a/tools/pnnx/src/pass_level2/torch_topk.cpp b/tools/pnnx/src/pass_level2/torch_topk.cpp
new file mode 100644
index 00000000000..ae2f3d70ce0
--- /dev/null
+++ b/tools/pnnx/src/pass_level2/torch_topk.cpp
@@ -0,0 +1,44 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+class torch_topk : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+7 7
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 k
+pnnx.Input input_2 0 1 dim
+pnnx.Input input_3 0 1 largest
+pnnx.Input input_4 0 1 sorted
+aten::topk op_0 5 2 input k dim largest sorted values indices
+pnnx.Output output 2 0 values indices
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.topk";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_topk, 20)
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp
index 5b20a21018b..7009b9f8ae8 100644
--- a/tools/pnnx/src/pass_level3/fuse_expression.cpp
+++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp
@@ -100,6 +100,7 @@ static bool operand_maybe_tensor(const Operand* operand)
     if (op->type == "aten::atan2"
             || op->type == "aten::div"
             || op->type == "aten::floor_divide"
+            || op->type == "aten::fmod"
            || op->type == "aten::mul"
            || op->type == "aten::pow"
            || op->type == "aten::remainder")
@@ -363,7 +364,35 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, const std::set<std::string>& foldable_constants, StoreZipReader& zip)
         fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
         expr += ")";
     }
-    else if (op->type == "aten::to" || op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
+    else if (op->type == "Tensor.to")
+    {
+        bool noop_type_cast = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type);
+        if (noop_type_cast)
+        {
+            fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
+        }
+        else
+        {
+            auto it = std::find(inputs.begin(), inputs.end(), operand);
+            if (it == inputs.end())
+            {
+                // tensor
+                char tmp[32];
+                sprintf(tmp, "@%d", (int)inputs.size());
+                expr += tmp;
+
+                inputs.push_back(operand);
+            }
+            else
+            {
+                // tensor
+                char tmp[32];
+                sprintf(tmp, "@%d", (int)(it - inputs.begin()));
+                expr += tmp;
+            }
+        }
+    }
+    else if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
     {
         fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
     }
@@ -402,8 +431,8 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, const std::set<std::string>& foldable_constants, StoreZipReader& zip)
         expr += ")";
     }
     else if (op->type == "aten::atan2"
-             || op->type == "aten::div"
              || op->type == "aten::floor_divide"
+             || op->type == "aten::fmod"
              || op->type == "aten::mul"
              || op->type == "aten::pow"
              || op->type == "aten::remainder")
@@ -484,6 +513,27 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, const std::set<std::string>& foldable_constants, StoreZipReader& zip)
         fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
         expr += ")";
     }
+    else if (op->type == "aten::div")
+    {
+        std::string rounding_mode;
+        if (op->inputs.size() == 3)
+            fuse_expression(graph, op->inputs[2], rounding_mode, inputs, foldable_constants, zip);
+
+        if (rounding_mode == "trunc")
+        {
+            expr += "floor_divide";
+        }
+        else
+        {
+            expr += "div";
+        }
+
+        expr += "(";
+        fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
+        expr += ",";
+        fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip);
+        expr += ")";
+    }
     else
     {
         auto it = std::find(inputs.begin(), inputs.end(), operand);
@@ -542,7 +592,13 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constants, StoreZipReader& zip)
         {
             need_fuse = true;
         }
-        if (op->type == "aten::to" || op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
+        if (op->type == "Tensor.to")
+        {
+            // fuse noop type cast only
+            bool noop_to = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type);
+            need_fuse = noop_to;
+        }
+        if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
         {
             need_fuse = true;
         }
@@ -562,6 +618,7 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constants, StoreZipReader& zip)
             || op->type == "aten::exp"
             || op->type == "aten::floor"
             || op->type == "aten::floor_divide"
+            || op->type == "aten::fmod"
             || op->type == "aten::log"
             || op->type == "aten::log10"
             || op->type == "aten::mul"
diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp
index 934a8eee42c..5d90c9554fa 100644
--- a/tools/pnnx/src/pass_level5.cpp
+++ b/tools/pnnx/src/pass_level5.cpp
@@ -27,6 +27,7 @@
 #include "pass_level5/eliminate_noop_slice.h"
 #include "pass_level5/eliminate_noop_view_reshape.h"
 #include "pass_level5/eliminate_reshape_shape_expression.h"
+#include "pass_level5/eliminate_type_as.h"
 #include "pass_level5/eval_expression.h"
 #include "pass_level5/fuse_adjacent_reshape.h"
 #include "pass_level5/fuse_channel_shuffle.h"
@@ -112,6 +113,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)

     eliminate_noop_cat(g);
     eliminate_dropout(g);
+    eliminate_type_as(g);

     eliminate_noop_upsample(g);

diff --git a/tools/pnnx/src/pass_level5/eliminate_type_as.cpp b/tools/pnnx/src/pass_level5/eliminate_type_as.cpp
new file mode 100644
index 00000000000..c7290fb0480
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/eliminate_type_as.cpp
@@ -0,0 +1,84 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "eliminate_type_as.h"
+
+#include <algorithm>
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+void eliminate_type_as(Graph& graph)
+{
+    while (1)
+    {
+        bool matched = false;
+
+        for (size_t i = 0; i < graph.ops.size(); i++)
+        {
+            Operator* op = graph.ops[i];
+
+            if (op->type != "Tensor.type_as")
+                continue;
+
+            if (op->inputs[0]->type == 0 || op->outputs[0]->type == 0)
+                continue;
+
+            if (op->inputs[0]->type != op->outputs[0]->type)
+                continue;
+
+            // delete noop-like type_as
+            matched = true;
+
+            for (auto& x : op->inputs)
+            {
+                x->remove_consumer(op);
+            }
+
+            Operand* type_as_out = op->outputs[0];
+
+            for (auto& x : type_as_out->consumers)
+            {
+                for (size_t j = 0; j < x->inputs.size(); j++)
+                {
+                    if (x->inputs[j] == type_as_out)
+                        x->inputs[j] = op->inputs[0];
+                }
+
+                op->inputs[0]->consumers.push_back(x);
+            }
+
+            op->inputs[0]->name = type_as_out->name;
+
+            type_as_out->producer = 0;
+            type_as_out->consumers.clear();
+
+            graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), type_as_out));
+            delete type_as_out;
+
+            op->inputs.clear();
+            op->outputs.clear();
+
+            graph.ops.erase(graph.ops.begin() + i);
+            delete op;
+
+            break;
+        }
+
+        if (!matched)
+            break;
+    }
+}
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/eliminate_type_as.h b/tools/pnnx/src/pass_level5/eliminate_type_as.h
new file mode 100644
index 00000000000..46ec5ad571b
--- /dev/null
+++ b/tools/pnnx/src/pass_level5/eliminate_type_as.h
@@ -0,0 +1,21 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "ir.h"
+
+namespace pnnx {
+
+void eliminate_type_as(Graph& graph);
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp
index 11cda70117c..f091da8a73d 100644
--- a/tools/pnnx/src/pass_level5/eval_expression.cpp
+++ b/tools/pnnx/src/pass_level5/eval_expression.cpp
@@ -342,6 +342,7 @@ static std::string eval_expression(const Operator* op)
                  || t == "mul"
                  || t == "div"
                  || t == "floor_divide"
+                 || t == "fmod"
                  || t == "pow"
                  || t == "remainder")
         {
@@ -380,6 +381,11 @@ static std::string eval_expression(const Operator* op)
                 float r = af / bf;
                 exprstack.push(std::to_string(r));
             }
+            if (t == "fmod")
+            {
+                float r = fmod(af, bf);
+                exprstack.push(std::to_string(r));
+            }
             if (t == "floor_divide")
             {
                 int r = (int)af / (int)bf;
diff --git a/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp
index a76d741bb19..5a21f45c5db 100644
--- a/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp
+++ b/tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp
@@ -37,6 +37,12 @@ void fuse_select_to_unbind(Graph& graph)
             if (input_rank == 0)
                 continue;

+            if (input_rank == 1)
+            {
+                // skip select scalar
+                continue;
+            }
+
             int dim = op->params.at("dim").i;

             const int select_dimsize = op_in->shape[dim];
diff --git a/tools/pnnx/src/pass_ncnn.cpp b/tools/pnnx/src/pass_ncnn.cpp
index cd307b007cd..4bc69b379bb 100644
--- a/tools/pnnx/src/pass_ncnn.cpp
+++ b/tools/pnnx/src/pass_ncnn.cpp
@@ -22,6 +22,7 @@
 #include "pass_ncnn/convert_torch_chunk.h"
 #include "pass_ncnn/convert_torch_einsum.h"
 #include "pass_ncnn/convert_torch_split.h"
+#include "pass_ncnn/convert_torch_stack.h"
 #include "pass_ncnn/convert_torch_tensor_split.h"
 #include "pass_ncnn/convert_torch_unbind.h"
 #include "pass_ncnn/convert_Tensor_select.h"
@@ -96,6 +97,7 @@ void pass_ncnn(Graph& g)

     ncnn::convert_torch_cat(g);
     ncnn::convert_torch_chunk(g);
+    ncnn::convert_torch_stack(g);
     ncnn::convert_torch_split(g);
     ncnn::convert_torch_unbind(g);
     ncnn::convert_torch_tensor_split(g);
diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_stack.cpp b/tools/pnnx/src/pass_ncnn/convert_torch_stack.cpp
new file mode 100644
index 00000000000..aa56ab5f641
--- /dev/null
+++ b/tools/pnnx/src/pass_ncnn/convert_torch_stack.cpp
@@ -0,0 +1,91 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "convert_torch_stack.h"
+
+namespace pnnx {
+
+namespace ncnn {
+
+void convert_torch_stack(Graph& graph)
+{
+    int op_index = 0;
+
+    while (1)
+    {
+        bool matched = false;
+
+        for (Operator* op : graph.ops)
+        {
+            if (op->type != "torch.stack")
+                continue;
+
+            matched = true;
+
+            op->type = "Concat";
+            op->name = std::string("stack_") + std::to_string(op_index++);
+
+            const int batch_index = op->inputs[0]->params["__batch_index"].i;
+
+            int axis = op->params.at("dim").i;
+            if (axis == batch_index)
+            {
+                fprintf(stderr, "stack along batch axis %d is not supported\n", batch_index);
+                continue;
+            }
+
+            if (axis < 0)
+            {
+                int input_rank = op->inputs[0]->shape.size();
+                axis = input_rank + axis;
+            }
+
+            if (axis > batch_index)
+                axis -= 1;
+
+            op->params["0"] = axis;
+
+            op->params.erase("dim");
+
+            // reshape for output, expand the stack dim
+            {
+                Operand* out = op->outputs[0];
+
+                Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op);
+
+                Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in");
+
+                reshape->inputs.push_back(reshape_in);
+                reshape->outputs.push_back(out);
+
+                op->outputs[0] = reshape_in;
+
+                out->producer = reshape;
+                reshape_in->producer = op;
+                reshape_in->consumers.push_back(reshape);
+
+                reshape->params["shape"] = out->shape;
+            }
+
+            break;
+        }
+
+        if (!matched)
+            break;
+    }
+}
+
+} // namespace ncnn
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_stack.h b/tools/pnnx/src/pass_ncnn/convert_torch_stack.h
new file mode 100644
index 00000000000..40495b730d5
--- /dev/null
+++ b/tools/pnnx/src/pass_ncnn/convert_torch_stack.h
@@ -0,0 +1,25 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_ncnn.h"
+
+namespace pnnx {
+
+namespace ncnn {
+
+void convert_torch_stack(Graph& graph);
+
+} // namespace ncnn
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.cpp b/tools/pnnx/src/pass_ncnn/expand_expression.cpp
index 28346c0683f..f6022be665f 100644
--- a/tools/pnnx/src/pass_ncnn/expand_expression.cpp
+++ b/tools/pnnx/src/pass_ncnn/expand_expression.cpp
@@ -178,7 +178,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx_expr_index)
         op_unary->inputs.push_back(op_unary_in);
         op_unary->outputs.push_back(op_unary_out);
     }
-    else if (t == "add" || t == "sub" || t == "mul" || t == "div" || /*t == "floor_divide" || */ t == "pow" || t == "atan2")
+    else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "fmod" || t == "remainder" || t == "pow" || t == "atan2")
     {
         std::string a = exprstack.top();
         exprstack.pop();
@@ -190,10 +190,16 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx_expr_index)

         Operator* op_binary = graph.new_operator_before("BinaryOp", t + "_" + std::to_string(pnnx_expr_index++), op);

+        // default todo type mark :[
+        op_binary->params["0"] = -1;
+
         if (t == "add") op_binary->params["0"] = 0;
         if (t == "sub") op_binary->params["0"] = 1;
         if (t == "mul") op_binary->params["0"] = 2;
         if (t == "div") op_binary->params["0"] = 3;
+        if (t == "floor_divide") fprintf(stderr, "BinaryOp floor_divide not supported yet\n"); // TODO
+        if (t == "fmod") fprintf(stderr, "BinaryOp fmod not supported yet\n"); // TODO
+        if (t == "remainder") fprintf(stderr, "BinaryOp remainder not supported yet\n"); // TODO
         if (t == "pow") op_binary->params["0"] = 6;
         if (t == "atan2") op_binary->params["0"] = 10;

diff --git a/tools/pnnx/src/pass_ncnn/insert_reshape_linear.cpp b/tools/pnnx/src/pass_ncnn/insert_reshape_linear.cpp
index 7702746d420..f97691506a4 100644
--- a/tools/pnnx/src/pass_ncnn/insert_reshape_linear.cpp
+++ b/tools/pnnx/src/pass_ncnn/insert_reshape_linear.cpp
@@ -94,19 +94,26 @@ void insert_reshape_linear(Graph& graph)
                 reshape_h *= linear_in->shape[j];
             }

-            std::vector<int> reshape0_shape;
+            std::vector<int> reshape0_out_shape;
+            std::vector<int> reshape1_in_shape;
             if (batch_index == 0 && batch_index != 233)
             {
-                reshape0_shape = {1, reshape_h, linear_in->shape[input_rank - 1]};
+                reshape0_out_shape = {1, reshape_h, linear_in->shape[input_rank - 1]};
+                reshape1_in_shape = {1, reshape_h, linear_out->shape[input_rank - 1]};
             }
             else
             {
-                reshape0_shape = {reshape_h, linear_in->shape[input_rank - 1]};
+                reshape0_out_shape = {reshape_h, linear_in->shape[input_rank - 1]};
+                reshape1_in_shape = {reshape_h, linear_out->shape[input_rank - 1]};
             }

-            std::vector<int> reshape1_shape = linear_out->shape;
-
-            reshape0->params["shape"] = reshape0_shape;
-            reshape1->params["shape"] = reshape1_shape;
+            std::vector<int> reshape1_out_shape = linear_out->shape;
+
+            reshape0->params["shape"] = reshape0_out_shape;
+            reshape1->params["shape"] = reshape1_out_shape;
+            reshape0_out->type = linear_in->type;
+            reshape0_out->shape = reshape0_out_shape;
+            reshape1_in->type = linear_out->type;
+            reshape1_in->shape = reshape1_in_shape;

             break;
         }
diff --git a/tools/pnnx/src/pass_ncnn/nn_Linear.cpp b/tools/pnnx/src/pass_ncnn/nn_Linear.cpp
index d828afab1a2..6a6edbeab2f 100644
--- a/tools/pnnx/src/pass_ncnn/nn_Linear.cpp
+++ b/tools/pnnx/src/pass_ncnn/nn_Linear.cpp
@@ -18,6 +18,152 @@ namespace pnnx {

 namespace ncnn {

+class nn_Linear_0 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input #input=(1,%m,%in_features)f32 +nn.Linear op_0 1 1 input out in_features=%in_features out_features=%out_features bias=%bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Gemm"; + } + + const char* name_str() const + { + return "gemm"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["2"] = 0; + op->params["3"] = 1; + op->params["4"] = 0; + op->params["5"] = 1; + op->params["6"] = 1; + op->params["7"] = captured_params.at("m"); + op->params["8"] = captured_params.at("out_features"); + op->params["9"] = captured_params.at("in_features"); + op->params["10"] = captured_params.at("bias").b ? 4 : -1; + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight"); + if (captured_params.at("bias").b) + { + op->attrs["2"] = Attribute(); + op->attrs["2"].data = {0, 0, 0, 0}; + op->attrs["3"] = captured_attrs.at("op_0.bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_0, 19) + +class nn_Linear_01 : public nn_Linear_0 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input #input=(%m,%in_features)f32 +nn.Linear op_0 1 1 input out in_features=%in_features out_features=%out_features bias=%bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + bool match(const std::map& captured_params) const + { + const int m = captured_params.at("m").i; + + if (m == 1) + return false; + + return true; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_01, 19) + +class nn_Linear_10 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input #input=(1,%m,%in_features)f32 +pnnx.Input input_1 0 1 bias +nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Gemm"; + } + + const char* name_str() const + { + return "gemm"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["2"] = 0; + op->params["3"] = 1; + op->params["4"] = 0; + op->params["5"] = 1; + op->params["6"] = 0; + op->params["7"] = captured_params.at("m"); + op->params["8"] = captured_params.at("out_features"); + op->params["9"] = captured_params.at("in_features"); + op->params["10"] = 4; + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_10, 19) + +class nn_Linear_11 : public nn_Linear_10 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input #input=(%m,%in_features)f32 +pnnx.Input input_1 0 1 bias +nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False +pnnx.Output output 1 0 out +)PNNXIR"; + } + + bool match(const std::map& captured_params) const + { + const int m = captured_params.at("m").i; + + if (m == 1) + return false; + + return true; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_11, 19) + class nn_Linear : public GraphRewriterPass { public: @@ -57,6 +203,52 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear, 20) +class nn_Linear_1 : public 
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 bias
+nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+5 4
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 bias
+InnerProduct linear 1 1 input a
+BinaryOp bias 2 1 a bias out 0=0
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        GraphRewriterPass::write(ops, captured_params, captured_attrs);
+
+        const int batch_index = ops.at("linear")->inputs[0]->params["__batch_index"].i;
+
+        ops.at("linear")->params["0"] = captured_params.at("out_features");
+        ops.at("linear")->params["1"] = 0;
+        ops.at("linear")->params["2"] = captured_attrs.at("op_0.weight").elemcount();
+
+        ops.at("linear")->attrs["0"] = Attribute();
+        ops.at("linear")->attrs["0"].data = {0, 0, 0, 0};
+        ops.at("linear")->attrs["1"] = captured_attrs.at("op_0.weight");
+
+        ops.at("linear")->outputs[0]->params["__batch_index"] = batch_index;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_1, 20)
+
 } // namespace ncnn

 } // namespace pnnx
diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt
index 3f9b5367502..626d549991d 100644
--- a/tools/pnnx/tests/CMakeLists.txt
+++ b/tools/pnnx/tests/CMakeLists.txt
@@ -162,7 +162,9 @@ pnnx_add_test(nn_ZeroPad2d)

 pnnx_add_test(Tensor_contiguous)
 pnnx_add_test(Tensor_expand)
+pnnx_add_test(Tensor_fill)
 pnnx_add_test(Tensor_index)
+pnnx_add_test(Tensor_index_put)
 pnnx_add_test(Tensor_masked_fill)
 pnnx_add_test(Tensor_new_empty)
 pnnx_add_test(Tensor_new_full)
@@ -173,6 +175,8 @@ pnnx_add_test(Tensor_reshape)
 pnnx_add_test(Tensor_select)
 pnnx_add_test(Tensor_slice)
 pnnx_add_test(Tensor_slice_copy)
+pnnx_add_test(Tensor_to)
+pnnx_add_test(Tensor_type_as)
 pnnx_add_test(Tensor_view)

 pnnx_add_test(torch_addmm)
@@ -221,6 +225,7 @@ pnnx_add_test(torch_squeeze)
 pnnx_add_test(torch_stack)
 pnnx_add_test(torch_std)
 pnnx_add_test(torch_tensor_split)
+pnnx_add_test(torch_topk)
 pnnx_add_test(torch_transpose)
 pnnx_add_test(torch_unbind)
 pnnx_add_test(torch_unsqueeze)
@@ -295,6 +300,7 @@ pnnx_add_test(pnnx_eliminate_noop_cat)
 pnnx_add_test(pnnx_eliminate_noop_expand)
 pnnx_add_test(pnnx_eliminate_noop_math)
 pnnx_add_test(pnnx_eliminate_noop_upsample)
+pnnx_add_test(pnnx_expression)
 pnnx_add_test(pnnx_fold_constant)
 pnnx_add_test(pnnx_fuse_conv1d_batchnorm1d)
 pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d)
diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt
index fd059eef62d..945576bfaf6 100644
--- a/tools/pnnx/tests/ncnn/CMakeLists.txt
+++ b/tools/pnnx/tests/ncnn/CMakeLists.txt
@@ -154,6 +154,7 @@ pnnx_ncnn_add_test(torch_permute)
 pnnx_ncnn_add_test(torch_prod)
 pnnx_ncnn_add_test(torch_sum)
 pnnx_ncnn_add_test(torch_squeeze)
+pnnx_ncnn_add_test(torch_stack)
 pnnx_ncnn_add_test(torch_tensor_split)
 pnnx_ncnn_add_test(torch_transpose)
 pnnx_ncnn_add_test(torch_unbind)
diff --git a/tools/pnnx/tests/ncnn/test_torch_stack.py b/tools/pnnx/tests/ncnn/test_torch_stack.py
new file mode 100644
index 00000000000..74d9e2a9ea3
--- /dev/null
+++ b/tools/pnnx/tests/ncnn/test_torch_stack.py
@@ -0,0 +1,60 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z, w):
+        out0 = torch.stack((x, y), dim=0)
+        out1 = torch.stack((z, w), dim=2)
+        out0.relu_()
+        out1.relu_()
+        return out0, out1
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(3, 16)
+    y = torch.rand(3, 16)
+    z = torch.rand(5, 9, 3)
+    w = torch.rand(5, 9, 3)
+
+    a0, a1 = net(x, y, z, w)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y, z, w))
+    mod.save("test_torch_stack.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../../src/pnnx test_torch_stack.pt inputshape=[3,16],[3,16],[5,9,3],[5,9,3]")
+
+    # ncnn inference
+    import test_torch_stack_ncnn
+    b0, b1 = test_torch_stack_ncnn.test_inference()
+
+    return torch.equal(a0, b0) and torch.equal(a1, b1)
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/test_Tensor_fill.py b/tools/pnnx/tests/test_Tensor_fill.py
new file mode 100644
index 00000000000..2c80d56c414
--- /dev/null
+++ b/tools/pnnx/tests/test_Tensor_fill.py
@@ -0,0 +1,57 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z):
+        x[:2,:].fill_(z[0])
+        y[:1,:].fill_(0.22)
+        return x + y.fill_(7)
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(6, 16)
+    y = torch.rand(6, 16)
+    z = torch.rand(1)
+
+    a = net(x, y, z)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y, z))
+    mod.save("test_Tensor_fill.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_Tensor_fill.pt inputshape=[6,16],[6,16],[1]")
+
+    # pnnx inference
+    import test_Tensor_fill_pnnx
+    b = test_Tensor_fill_pnnx.test_inference()
+
+    return torch.equal(a, b)
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/test_Tensor_index_put.py b/tools/pnnx/tests/test_Tensor_index_put.py
new file mode 100644
index 00000000000..968c785b0a9
--- /dev/null
+++ b/tools/pnnx/tests/test_Tensor_index_put.py
@@ -0,0 +1,63 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z, w):
+        x = x.clone()
+        z = z.clone()
+        x = x.index_put(indices=[torch.tensor([10,2])], values=y, accumulate=False)
+        z.index_put_(indices=[torch.tensor([1,0,0]), torch.tensor([3,2,1])], values=w, accumulate=True)
+        return x, z
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(12)
+    y = torch.rand(2)
+    z = torch.rand(6,9)
+    w = torch.rand(3)
+
+    a = net(x, y, z, w)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y, z, w))
+    mod.save("test_Tensor_index_put.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_Tensor_index_put.pt inputshape=[12],[2],[6,9],[3]")
+
+    # pnnx inference
+    import test_Tensor_index_put_pnnx
+    b = test_Tensor_index_put_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/test_Tensor_to.py b/tools/pnnx/tests/test_Tensor_to.py
new file mode 100644
index 00000000000..71c157cb341
--- /dev/null
+++ b/tools/pnnx/tests/test_Tensor_to.py
@@ -0,0 +1,63 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y):
+        x = x * 10
+        y = y * 13
+        y = y.to(dtype=x.dtype, memory_format=torch.contiguous_format)
+        x = x.to(device='cpu', dtype=torch.int, copy=True)
+        x = x + 1
+        y = y - 2
+        return x, y
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(3, 16)
+    y = torch.randint(10, (1, 13), dtype=torch.int)
+
+    a = net(x, y)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y))
+    mod.save("test_Tensor_to.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_Tensor_to.pt inputshape=[3,16],[1,13]i32")
+
+    # pnnx inference
+    import test_Tensor_to_pnnx
+    b = test_Tensor_to_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/test_Tensor_type_as.py b/tools/pnnx/tests/test_Tensor_type_as.py
new file mode 100644
index 00000000000..cdb7a540072
--- /dev/null
+++ b/tools/pnnx/tests/test_Tensor_type_as.py
@@ -0,0 +1,65 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z):
+        x = x * 100
+        z = z * 200
+        x = x.type_as(y)
+        x = F.relu(x)
+        x = x.type_as(z)
+        z = F.relu(z)
+        z = z.type_as(x)
+        return x, z
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(3, 16)
+    y = torch.randint(10, (1, 13), dtype=torch.int)
+    z = torch.rand(8, 5, 9, 10)
+
+    a = net(x, y, z)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y, z))
+    mod.save("test_Tensor_type_as.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_Tensor_type_as.pt inputshape=[3,16],[1,13]i32,[8,5,9,10]")
+
+    # pnnx inference
+    import test_Tensor_type_as_pnnx
+    b = test_Tensor_type_as_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/test_pnnx_expression.py b/tools/pnnx/tests/test_pnnx_expression.py
new file mode 100644
index 00000000000..7a7dc9967bd
--- /dev/null
+++ b/tools/pnnx/tests/test_pnnx_expression.py
@@ -0,0 +1,75 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.w0 = nn.Parameter(torch.rand(12, 15))
+        self.w1 = nn.Parameter(torch.rand(12, 15))
+        self.w2 = nn.Parameter(torch.rand(12, 15))
+        self.w3 = nn.Parameter(torch.rand(12, 15))
+        self.w4 = nn.Parameter(torch.rand(12, 15))
+        self.w5 = nn.Parameter(torch.rand(12, 15))
+
+    def forward(self, x):
+        x0 = x * 10
+        x = x + self.w0 + x0
+        x = x - self.w1 + x0.float()
+        x = x * self.w2 + x0
+        x = x / self.w3 + x0
+        x = x // self.w4 + x0
+        if version.parse(torch.__version__) >= version.parse('2.0'):
+            x = x % self.w5 + x0
+        else:
+            x = torch.fmod(x, self.w5) + x0
+        y = x.int()
+        return x, y & 3, y | 3, y ^ 3, y << 3, y >> 3
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(12, 15)
+
+    a = net(x)
+
+    # export torchscript
+    mod = torch.jit.trace(net, x)
+    mod.save("test_pnnx_expression.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_pnnx_expression.pt inputshape=[12,15]")
+
+    # pnnx inference
+    import test_pnnx_expression_pnnx
+    b = test_pnnx_expression_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/test_torch_einsum.py b/tools/pnnx/tests/test_torch_einsum.py
index cde16e3beee..c045f4b8980 100644
--- a/tools/pnnx/tests/test_torch_einsum.py
+++ b/tools/pnnx/tests/test_torch_einsum.py
@@ -148,7 +148,10 @@ def test():
     b = test_torch_einsum_pnnx.test_inference()

     for a0, b0 in zip(a, b):
-        if not torch.equal(a0, b0):
+        # allclose may auto broadcast compare
+        if a0.shape != b0.shape:
+            return False
+        if not torch.allclose(a0, b0, 1e-4, 1e-4):
             return False
     return True

diff --git a/tools/pnnx/tests/test_torch_topk.py b/tools/pnnx/tests/test_torch_topk.py
new file mode 100644
index 00000000000..e2be60aae86
--- /dev/null
+++ b/tools/pnnx/tests/test_torch_topk.py
@@ -0,0 +1,61 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z):
+        x, _ = torch.topk(x, 4)
+        y, _ = torch.topk(y, k=1, dim=2, largest=False)
+        z, indices = torch.topk(z, k=3, dim=-1, sorted=False)
+        return x, y, z, indices
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(1, 3, 16)
+    y = torch.rand(1, 5, 9, 11)
+    z = torch.rand(14, 8, 5, 9, 10)
+
+    a = net(x, y, z)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y, z))
+    mod.save("test_torch_topk.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_torch_topk.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")
+
+    # pnnx inference
+    import test_torch_topk_pnnx
+    b = test_torch_topk_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
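
For reference — this snippet is an editor-added illustration, not part of the patch — the reason Tensor_to.cpp needs two rewriter classes is that x.to(...) traces to two different aten::to overloads depending on whether a device is given. A quick way to see the 5-input and 6-input call shapes that Tensor_to and Tensor_to_1 match:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        # dtype-only: traces to the 5-input aten::to(input, dtype, non_blocking, copy, memory_format)
        a = x.to(dtype=torch.int, copy=True)
        # device + dtype: traces to the 6-input aten::to(input, device, dtype, non_blocking, copy, memory_format)
        b = x.to(device='cpu', dtype=torch.half, copy=True)
        return a, b

mod = torch.jit.trace(M().eval(), torch.rand(3, 16))
print(mod.graph)  # shows both aten::to call shapes matched by the two PNNXIR patterns

The 5-input form carries (dtype, non_blocking, copy, memory_format) and the 6-input form adds a leading device operand, which is exactly the difference between the two match_pattern_graph() strings above.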