Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick2.9]add fuse flatten_contiguous_range and fc pass #6057

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ USE_MIR_PASS(control_flow_op_shared_inputs_and_outputs_place_sync_pass);
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(lite_instance_norm_activation_fuse_pass);
USE_MIR_PASS(ssd_boxes_calc_offline_pass);
USE_MIR_PASS(lite_flatten_fc_fuse_pass);
USE_MIR_PASS(lite_fc_prelu_fuse_pass);
USE_MIR_PASS(__xpu__graph_dedup_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
Expand Down
1 change: 1 addition & 0 deletions lite/core/mir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ lite_cc_library(mir_passes
fusion/sequence_reverse_embedding_fuse_pass.cc
fusion/instance_norm_activation_fuse_pass.cc
fusion/elementwise_add_scale_fuse_pass.cc
fusion/flatten_fc_fuse_pass.cc
fusion/fc_prelu_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
Expand Down
4 changes: 4 additions & 0 deletions lite/core/mir/fusion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ lite_cc_library(fuse_instance_norm_activation
lite_cc_library(fuse_elementwise_add_scale
SRCS elementwise_add_scale_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_flatten_fc
SRCS flatten_fc_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_fc_prelu
SRCS fc_prelu_fuser.cc
DEPS pattern_matcher_high_api)
Expand Down Expand Up @@ -98,6 +101,7 @@ set(mir_fusers
fuse_sequence_reverse_embedding
fuse_instance_norm_activation
fuse_elementwise_add_scale
fuse_flatten_fc
fuse_fc_prelu
fuse_conv_scale
CACHE INTERNAL "fusers")
Expand Down
37 changes: 37 additions & 0 deletions lite/core/mir/fusion/flatten_fc_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/mir/fusion/flatten_fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/flatten_fc_fuser.h"
#include "lite/core/mir/pass_registry.h"

namespace paddle {
namespace lite {
namespace mir {

void FlattenFcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::FlattenFcFuser flatten_fuser(" ");
flatten_fuser(graph.get());
}

} // namespace mir
} // namespace lite
} // namespace paddle

REGISTER_MIR_PASS(lite_flatten_fc_fuse_pass,
paddle::lite::mir::FlattenFcFusePass)
.BindTargets({TARGET(kOpenCL)})
.BindKernel("fc");
32 changes: 32 additions & 0 deletions lite/core/mir/fusion/flatten_fc_fuse_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include "lite/core/mir/pass.h"

namespace paddle {
namespace lite {
namespace mir {

class FlattenFcFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};

} // namespace mir
} // namespace lite
} // namespace paddle
86 changes: 86 additions & 0 deletions lite/core/mir/fusion/flatten_fc_fuser.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/mir/fusion/flatten_fc_fuser.h"
#include <memory>
#include <vector>

namespace paddle {
namespace lite {
namespace mir {
namespace fusion {

void FlattenFcFuser::BuildPattern() {
// flatten_contiguous_range
PMNode* x = VarNode("x")
->assert_is_op_input("flatten_contiguous_range", "X")
->AsInput();
PMNode* flatten_contiguous_range =
OpNode("flatten_contiguous_range", "flatten_contiguous_range")
->AsIntermediate();
PMNode* out = VarNode("output")
->assert_is_op_output("flatten_contiguous_range", "Out")
->AsIntermediate();
PMNode* xshape =
VarNode("xshape")
->assert_is_op_output("flatten_contiguous_range", "XShape")
->AsIntermediate();

// fc
// PMNode* input = VarNode("input")->assert_is_op_input("fc",
// "Input")->AsIntermediate();
PMNode* weights =
VarNode("weights")->assert_is_op_input("fc", "W")->AsInput();
PMNode* bias = VarNode("bias")->assert_is_op_input("fc", "Bias")->AsInput();
PMNode* fc = OpNode("fc", "fc")->AsIntermediate();
PMNode* fc_out =
VarNode("fc_out")->assert_is_op_output("fc", "Out")->AsOutput();

// create topology.
std::vector<PMNode*> fc_inputs{bias, weights, out};
*x >> *flatten_contiguous_range >> *out;
*flatten_contiguous_range >> *xshape;
fc_inputs >> *fc >> *fc_out;
}

void FlattenFcFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto fc_old = matched.at("fc")->stmt()->op();
auto* scope = fc_old->scope();
auto& valid_places = fc_old->valid_places();
fc_op->Attach(op_desc, scope);

auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);

IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("weights"), new_op_node);
IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("fc_out"));
}

cpp::OpDesc FlattenFcFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("fc")->stmt()->op_info();
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetOutput("Out", {matched.at("fc_out")->arg()->name});
int in_num_col_dim = 1;
op_desc.SetAttr("in_num_col_dims", in_num_col_dim);
return op_desc;
}

} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
39 changes: 39 additions & 0 deletions lite/core/mir/fusion/flatten_fc_fuser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"

namespace paddle {
namespace lite {
namespace mir {
namespace fusion {

class FlattenFcFuser : public FuseBase {
public:
explicit FlattenFcFuser(const std::string& type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};

} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
1 change: 1 addition & 0 deletions lite/core/optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class Optimizer {
"lite_scale_activation_fuse_pass", //
"lite_elementwise_scale_fuse_pass", //
"lite_instance_norm_activation_fuse_pass", //
"lite_flatten_fc_fuse_pass", //
"lite_fc_prelu_fuse_pass", //
"lite_elementwise_activation_fuse_pass",
"lite_conv_scale_fuse_pass",
Expand Down