From 91090d793b80d4bd8453641607acfa984a9be383 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 6 Jul 2023 18:55:23 +0800 Subject: [PATCH] pnnx fix build, prepend batch for broadcast reshape (#4841) * fix build, prepend batch for broadcast reshape * sanitize filename * do not fuse to eltwise if broadcast --- tools/pnnx/src/main.cpp | 5 ++++- .../pnnx/src/pass_level5/eval_expression.cpp | 6 ++---- .../src/pass_ncnn/fuse_binaryop_eltwise.cpp | 21 +++++++++++++++++++ ...nsert_reshape_numpy_binaryop_broadcast.cpp | 5 +++++ 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index e5253a97208..f2b9a545354 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -47,7 +47,10 @@ static std::string get_basename(const std::string& path) { - return path.substr(0, path.find_last_of('.')); + std::string base = path.substr(0, path.find_last_of('.')); + // sanitize - + std::replace(base.begin(), base.end(), '-', '_'); + return base; } static void parse_string_list(char* s, std::vector& list) diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp index a48b4c97675..11cda70117c 100644 --- a/tools/pnnx/src/pass_level5/eval_expression.cpp +++ b/tools/pnnx/src/pass_level5/eval_expression.cpp @@ -14,6 +14,8 @@ #include "eval_expression.h" +#include +#include #include #include @@ -276,14 +278,10 @@ static std::string eval_expression(const Operator* op) if (t == "round") { // round to nearest even -#if FLT_ROUNDS != FE_TONEAREST int old_rm = fegetround(); fesetround(FE_TONEAREST); -#endif float r = nearbyintf(af); -#if FLT_ROUNDS != FE_TONEAREST fesetround(old_rm); -#endif exprstack.push(std::to_string(r)); } if (t == "rsqrt") diff --git a/tools/pnnx/src/pass_ncnn/fuse_binaryop_eltwise.cpp b/tools/pnnx/src/pass_ncnn/fuse_binaryop_eltwise.cpp index c3b9e13c3c4..bee5c157386 100644 --- a/tools/pnnx/src/pass_ncnn/fuse_binaryop_eltwise.cpp +++ b/tools/pnnx/src/pass_ncnn/fuse_binaryop_eltwise.cpp @@ -48,6 +48,13 @@ pnnx.Output output 1 0 out return "weighted_sum"; } + bool match(const std::map& matched_operators) const + { + auto a_shape = matched_operators.at("op_0")->inputs[0]->shape; + auto b_shape = matched_operators.at("op_1")->inputs[0]->shape; + return !a_shape.empty() && a_shape == b_shape; + } + void write(Operator* op, const std::map& captured_params, const std::map& /*captured_attrs*/) const { float c0 = 1.f; @@ -93,6 +100,13 @@ pnnx.Output output 1 0 out return "weighted_sum"; } + bool match(const std::map& matched_operators) const + { + auto a_shape = matched_operators.at("op_0")->inputs[0]->shape; + auto b_shape = matched_operators.at("op_1")->inputs[1]->shape; + return !a_shape.empty() && a_shape == b_shape; + } + void write(Operator* op, const std::map& captured_params, const std::map& /*captured_attrs*/) const { float c0 = 1.f; @@ -133,6 +147,13 @@ pnnx.Output output 1 0 out return "weighted_sum"; } + bool match(const std::map& matched_operators) const + { + auto a_shape = matched_operators.at("op_1")->inputs[0]->shape; + auto b_shape = matched_operators.at("op_0")->inputs[0]->shape; + return !a_shape.empty() && a_shape == b_shape; + } + void write(Operator* op, const std::map& captured_params, const std::map& /*captured_attrs*/) const { float c0 = 1.f; diff --git a/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp b/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp index 50b39f8d72b..6eb46308d7b 100644 --- a/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp +++ b/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp @@ -138,6 +138,11 @@ void insert_reshape_numpy_binaryop_broadcast(Graph& graph) reshape0_shape.insert(reshape0_shape.begin(), 1); } + if (batch_index0 != 233) + { + reshape0_shape.insert(reshape0_shape.begin() + batch_index0, 1); + } + reshape0->params["shape"] = reshape0_shape; break;