From 02ed183432cfffcb9c1a49188228db1445b35a3d Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 23 Aug 2023 14:42:42 +0800 Subject: [PATCH] pnnx fuse moduleop unpack --- tools/pnnx/src/pass_level1.cpp | 53 ++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tools/pnnx/src/pass_level1.cpp b/tools/pnnx/src/pass_level1.cpp index c47ce0934e4..844b55426be 100644 --- a/tools/pnnx/src/pass_level1.cpp +++ b/tools/pnnx/src/pass_level1.cpp @@ -50,6 +50,56 @@ FuseModulePassRegister::~FuseModulePassRegister() delete pass; } +static void fuse_moduleop_unpack(Graph& graph, const std::vector& module_operators) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (std::find(module_operators.begin(), module_operators.end(), op->type) == module_operators.end()) + continue; + + if (op->outputs.size() != 1) + continue; + + if (op->outputs[0]->consumers.size() != 1) + continue; + + Operator* op2 = op->outputs[0]->consumers[0]; + if (op2->type != "prim::TupleUnpack") + continue; + + matched = true; + + op->outputs[0]->producer = 0; + op->outputs[0]->remove_consumer(op2); + + for (auto& x : op2->outputs) + { + x->producer = op; + } + + op->outputs = op2->outputs; + + op2->inputs.clear(); + op2->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + + delete op2; + + break; + } + + if (!matched) + break; + } +} + void pass_level1(const torch::jit::Module& mod, const std::shared_ptr& g, const std::vector& module_operators, Graph& pg) { for (int i = 1; i < (int)g->inputs().size(); i++) @@ -407,6 +457,9 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptrconsumers.push_back(op); op->inputs.push_back(r); } + + // post process + fuse_moduleop_unpack(pg, module_operators); } } // namespace pnnx