Skip to content

Commit

Permalink
Fix missing approximate parameters of nn.GELU
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed May 9, 2024
1 parent 8994bd0 commit 53299c5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2647,7 +2647,7 @@ int insert_function(FILE* pyfp, std::vector<std::string>& custom_ops_names, std:
std::ostringstream functionContent;
bool insideFunction = false;

std::regex functionStartRegex(R "(\s*def\s+(\w+)\s*\(.*\):)");
std::regex functionStartRegex(R"(\s*def\s+(\w+)\s*\(.*\):)");
std::regex pattern("^import");
std::smatch match;

Expand Down Expand Up @@ -2715,7 +2715,7 @@ int get_custom_op_names(std::string& customop_infer_py, std::vector<std::string>

std::string line;

std::regex functionStartRegex(R "(\s*def\s+(\w+)\s*\(.*\):)");
std::regex functionStartRegex(R"(\s*def\s+(\w+)\s*\(.*\):)");
std::smatch match;

while (std::getline(file, line))
Expand Down
9 changes: 8 additions & 1 deletion tools/pnnx/src/pass_level1/nn_GELU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"
namespace pnnx {

class GELU : public FuseModulePass
Expand All @@ -28,6 +28,13 @@ class GELU : public FuseModulePass
{
return "nn.GELU";
}

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const
{
const torch::jit::Node* celu = find_node_by_kind(graph, "aten::gelu");

op->params["approximate"] = celu->namedInput("approximate");
}
};

REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(GELU)
Expand Down

0 comments on commit 53299c5

Please sign in to comment.