From 8cade5b88fddb73b0fccb20b1fa9bcf9a1fb313f Mon Sep 17 00:00:00 2001 From: "sen.li" Date: Thu, 9 May 2024 10:12:51 +0800 Subject: [PATCH] update gelu param name --- tools/pnnx/src/ir.cpp | 4 ++-- tools/pnnx/src/pass_level1/nn_GELU.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 860ac2208a7..4c9c2771206 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2647,7 +2647,7 @@ int insert_function(FILE* pyfp, std::vector& custom_ops_names, std: std::ostringstream functionContent; bool insideFunction = false; - std::regex functionStartRegex(R "(\s*def\s+(\w+)\s*\(.*\):)"); + std::regex functionStartRegex(R"(\s*def\s+(\w+)\s*\(.*\):)"); std::regex pattern("^import"); std::smatch match; @@ -2715,7 +2715,7 @@ int get_custom_op_names(std::string& customop_infer_py, std::vector std::string line; - std::regex functionStartRegex(R "(\s*def\s+(\w+)\s*\(.*\):)"); + std::regex functionStartRegex(R"(\s*def\s+(\w+)\s*\(.*\):)"); std::smatch match; while (std::getline(file, line)) diff --git a/tools/pnnx/src/pass_level1/nn_GELU.cpp b/tools/pnnx/src/pass_level1/nn_GELU.cpp index d0d295a6e45..71ccc9568ff 100644 --- a/tools/pnnx/src/pass_level1/nn_GELU.cpp +++ b/tools/pnnx/src/pass_level1/nn_GELU.cpp @@ -31,9 +31,9 @@ class GELU : public FuseModulePass void write(Operator* op, const std::shared_ptr& graph) const { - const torch::jit::Node* celu = find_node_by_kind(graph, "aten::gelu"); + const torch::jit::Node* gelu = find_node_by_kind(graph, "aten::gelu"); - op->params["approximate"] = celu->namedInput("approximate"); + op->params["approximate"] = gelu->namedInput("approximate"); } };