Skip to content

Commit

Permalink
update gelu param name
Browse files Browse the repository at this point in the history
  • Loading branch information
sen.li committed May 9, 2024
1 parent f36677c commit 8cade5b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2647,7 +2647,7 @@ int insert_function(FILE* pyfp, std::vector<std::string>& custom_ops_names, std:
std::ostringstream functionContent;
bool insideFunction = false;

std::regex functionStartRegex(R "(\s*def\s+(\w+)\s*\(.*\):)");
std::regex functionStartRegex(R"(\s*def\s+(\w+)\s*\(.*\):)");
std::regex pattern("^import");
std::smatch match;

Expand Down Expand Up @@ -2715,7 +2715,7 @@ int get_custom_op_names(std::string& customop_infer_py, std::vector<std::string>

std::string line;

std::regex functionStartRegex(R "(\s*def\s+(\w+)\s*\(.*\):)");
std::regex functionStartRegex(R"(\s*def\s+(\w+)\s*\(.*\):)");
std::smatch match;

while (std::getline(file, line))
Expand Down
4 changes: 2 additions & 2 deletions tools/pnnx/src/pass_level1/nn_GELU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class GELU : public FuseModulePass

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
{
const torch::jit::Node* celu = find_node_by_kind(graph, "aten::gelu");
const torch::jit::Node* gelu = find_node_by_kind(graph, "aten::gelu");

op->params["approximate"] = celu->namedInput("approximate");
op->params["approximate"] = gelu->namedInput("approximate");
}
};

Expand Down

0 comments on commit 8cade5b

Please sign in to comment.