Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PASS]fix conv_elementwise_tree_fuse_pass rm bug #6812

Merged
merged 2 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void ConvElementwiseTreeFusePass::Apply(
<< " elementwise_type: " << elementwise_type;
fusion::ConvElementwiseTreeFuser fuser(
conv_type, conv_has_bias, conv_has_prelu_alpha, elementwise_type);
fuser(graph.get());
fuser.apply_impl(graph.get());
}
}
}
Expand Down
16 changes: 10 additions & 6 deletions lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ void ConvElementwiseTreeFuser::BuildPattern() {
auto* conv_input =
VarNode("conv_input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* conv_filter = VarNode("conv_filter")
->assert_is_persistable_var()
->assert_is_op_input(conv_type_, "Filter")
->AsInput();
auto* elementwise_input = VarNode("elementwise_input")
Expand All @@ -35,7 +36,8 @@ void ConvElementwiseTreeFuser::BuildPattern() {
// create intermediate nodes
conv_output_ = VarNode("conv_output")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input(elementwise_type_, "Y");
->assert_is_op_input(elementwise_type_, "Y")
->assert_only_one_output();

// create op nodes
// The pass will not been applied if conv1x1 has already applied this pass.
Expand Down Expand Up @@ -79,13 +81,15 @@ void ConvElementwiseTreeFuser::BuildPattern() {
// consider two special cases: conv with bias, conv with prelu alpha
std::vector<PMNode*> conv_inputs{conv_input, conv_filter};
if (conv_has_bias_) {
auto* conv_bias =
VarNode("conv_bias")->assert_is_op_input(conv_type_, "Bias");
auto* conv_bias = VarNode("conv_bias")
->assert_is_op_input(conv_type_, "Bias")
->assert_is_persistable_var();
conv_inputs.push_back(conv_bias);
}
if (conv_has_prelu_alpha_) {
auto* conv_alpha = VarNode("conv_alpha")
->assert_is_op_input(conv_type_, "Prelu_alpha")
->assert_is_persistable_var()
->AsInput();
conv_inputs.push_back(conv_alpha);
}
Expand Down Expand Up @@ -146,9 +150,9 @@ void ConvElementwiseTreeFuser::InsertNewNode(SSAGraph* graph,
}

// NOTE: Mark these node as intermediate at this place.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的NOTE同步也修改下吧

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

conv_output_->AsIntermediate();
conv_->AsIntermediate();
elementwise_->AsIntermediate();
nodes2rm_.insert(matched.at("conv"));
nodes2rm_.insert(matched.at("conv_output"));
nodes2rm_.insert(matched.at("elementwise"));

auto op_desc = GenOpDesc(matched);
auto conv_op_new = LiteOpRegistry::Global().Create(conv_type_);
Expand Down
13 changes: 13 additions & 0 deletions lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include <memory>
#include <set>
#include <string>
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"

Expand All @@ -34,6 +35,17 @@ class ConvElementwiseTreeFuser : public FuseBase {
conv_has_prelu_alpha_ = conv_has_prelu_alpha;
elementwise_type_ = elementwise_type;
}
size_t apply_impl(SSAGraph* graph) {
BuildPattern();
PerformPatternMatcher(graph);

for (const auto& matched : key2nodes_) {
InsertNewNode(graph, matched);
}

GraphSafeRemoveNodes(graph, nodes2rm_);
return key2nodes_.size();
}

void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
Expand All @@ -48,6 +60,7 @@ class ConvElementwiseTreeFuser : public FuseBase {
PMNode* conv_output_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这三个PMNode*类型的成员变量可以去掉,改为临时变量。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

PMNode* conv_;
PMNode* elementwise_;
std::set<const Node*> nodes2rm_;
};

} // namespace fusion
Expand Down