Skip to content

Commit

Permalink
pnnx fix build, prepend batch for broadcast reshape (#4841)
Browse files Browse the repository at this point in the history
* fix build, prepend batch for broadcast reshape

* sanitize filename

* do not fuse to eltwise if broadcast
  • Loading branch information
nihui authored Jul 6, 2023
1 parent 47e0daf commit 91090d7
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 5 deletions.
5 changes: 4 additions & 1 deletion tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@

static std::string get_basename(const std::string& path)
{
return path.substr(0, path.find_last_of('.'));
std::string base = path.substr(0, path.find_last_of('.'));
// sanitize -
std::replace(base.begin(), base.end(), '-', '_');
return base;
}

static void parse_string_list(char* s, std::vector<std::string>& list)
Expand Down
6 changes: 2 additions & 4 deletions tools/pnnx/src/pass_level5/eval_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "eval_expression.h"

#include <fenv.h>
#include <float.h>
#include <math.h>

#include <iostream>
Expand Down Expand Up @@ -276,14 +278,10 @@ static std::string eval_expression(const Operator* op)
if (t == "round")
{
// round to nearest even
#if FLT_ROUNDS != FE_TONEAREST
int old_rm = fegetround();
fesetround(FE_TONEAREST);
#endif
float r = nearbyintf(af);
#if FLT_ROUNDS != FE_TONEAREST
fesetround(old_rm);
#endif
exprstack.push(std::to_string(r));
}
if (t == "rsqrt")
Expand Down
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_ncnn/fuse_binaryop_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ pnnx.Output output 1 0 out
return "weighted_sum";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
auto a_shape = matched_operators.at("op_0")->inputs[0]->shape;
auto b_shape = matched_operators.at("op_1")->inputs[0]->shape;
return !a_shape.empty() && a_shape == b_shape;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
float c0 = 1.f;
Expand Down Expand Up @@ -93,6 +100,13 @@ pnnx.Output output 1 0 out
return "weighted_sum";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
auto a_shape = matched_operators.at("op_0")->inputs[0]->shape;
auto b_shape = matched_operators.at("op_1")->inputs[1]->shape;
return !a_shape.empty() && a_shape == b_shape;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
float c0 = 1.f;
Expand Down Expand Up @@ -133,6 +147,13 @@ pnnx.Output output 1 0 out
return "weighted_sum";
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
auto a_shape = matched_operators.at("op_1")->inputs[0]->shape;
auto b_shape = matched_operators.at("op_0")->inputs[0]->shape;
return !a_shape.empty() && a_shape == b_shape;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
float c0 = 1.f;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ void insert_reshape_numpy_binaryop_broadcast(Graph& graph)
reshape0_shape.insert(reshape0_shape.begin(), 1);
}

if (batch_index0 != 233)
{
reshape0_shape.insert(reshape0_shape.begin() + batch_index0, 1);
}

reshape0->params["shape"] = reshape0_shape;

break;
Expand Down

0 comments on commit 91090d7

Please sign in to comment.