Skip to content

Commit

Permalink
pnnx fuse moduleop unpack
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Aug 23, 2023
1 parent 77c4421 commit 02ed183
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions tools/pnnx/src/pass_level1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,56 @@ FuseModulePassRegister::~FuseModulePassRegister()
delete pass;
}

static void fuse_moduleop_unpack(Graph& graph, const std::vector<std::string>& module_operators)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (std::find(module_operators.begin(), module_operators.end(), op->type) == module_operators.end())
continue;

if (op->outputs.size() != 1)
continue;

if (op->outputs[0]->consumers.size() != 1)
continue;

Operator* op2 = op->outputs[0]->consumers[0];
if (op2->type != "prim::TupleUnpack")
continue;

matched = true;

op->outputs[0]->producer = 0;
op->outputs[0]->remove_consumer(op2);

for (auto& x : op2->outputs)
{
x->producer = op;
}

op->outputs = op2->outputs;

op2->inputs.clear();
op2->outputs.clear();

graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2));

delete op2;

break;
}

if (!matched)
break;
}
}

void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit::Graph>& g, const std::vector<std::string>& module_operators, Graph& pg)
{
for (int i = 1; i < (int)g->inputs().size(); i++)
Expand Down Expand Up @@ -407,6 +457,9 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit
r->consumers.push_back(op);
op->inputs.push_back(r);
}

// post process
fuse_moduleop_unpack(pg, module_operators);
}

} // namespace pnnx

0 comments on commit 02ed183

Please sign in to comment.