From 0bbb7752b5eb2f4582280e7f701955d18522b07a Mon Sep 17 00:00:00 2001 From: daming5432 Date: Tue, 11 May 2021 14:12:51 +0800 Subject: [PATCH] [pass][OpenCL]add fuse flatten_contiguous_range and fc pass (#6040) * add fuse flatten_contiguous_range and fc pass * modify Copyright 2021 test=develop --- lite/api/paddle_use_passes.h | 1 + lite/core/mir/CMakeLists.txt | 1 + lite/core/mir/fusion/CMakeLists.txt | 4 + lite/core/mir/fusion/flatten_fc_fuse_pass.cc | 37 +++++++++ lite/core/mir/fusion/flatten_fc_fuse_pass.h | 32 ++++++++ lite/core/mir/fusion/flatten_fc_fuser.cc | 86 ++++++++++++++++++++ lite/core/mir/fusion/flatten_fc_fuser.h | 39 +++++++++ lite/core/optimizer.h | 1 + 8 files changed, 201 insertions(+) create mode 100644 lite/core/mir/fusion/flatten_fc_fuse_pass.cc create mode 100644 lite/core/mir/fusion/flatten_fc_fuse_pass.h create mode 100644 lite/core/mir/fusion/flatten_fc_fuser.cc create mode 100644 lite/core/mir/fusion/flatten_fc_fuser.h diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 5c55875c68d..2f93332b34f 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -78,6 +78,7 @@ USE_MIR_PASS(control_flow_op_shared_inputs_and_outputs_place_sync_pass); USE_MIR_PASS(lite_scale_activation_fuse_pass); USE_MIR_PASS(lite_instance_norm_activation_fuse_pass); USE_MIR_PASS(ssd_boxes_calc_offline_pass); +USE_MIR_PASS(lite_flatten_fc_fuse_pass); USE_MIR_PASS(lite_fc_prelu_fuse_pass); USE_MIR_PASS(__xpu__graph_dedup_pass); USE_MIR_PASS(__xpu__resnet_fuse_pass); diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index c4d22d7885c..c60a565ec15 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -58,6 +58,7 @@ lite_cc_library(mir_passes fusion/sequence_reverse_embedding_fuse_pass.cc fusion/instance_norm_activation_fuse_pass.cc fusion/elementwise_add_scale_fuse_pass.cc + fusion/flatten_fc_fuse_pass.cc fusion/fc_prelu_fuse_pass.cc elimination/identity_scale_eliminate_pass.cc elimination/identity_dropout_eliminate_pass.cc diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt index 78920ed241f..60832c54e3f 100644 --- a/lite/core/mir/fusion/CMakeLists.txt +++ b/lite/core/mir/fusion/CMakeLists.txt @@ -70,6 +70,9 @@ lite_cc_library(fuse_instance_norm_activation lite_cc_library(fuse_elementwise_add_scale SRCS elementwise_add_scale_fuser.cc DEPS pattern_matcher_high_api) +lite_cc_library(fuse_flatten_fc + SRCS flatten_fc_fuser.cc + DEPS pattern_matcher_high_api) lite_cc_library(fuse_fc_prelu SRCS fc_prelu_fuser.cc DEPS pattern_matcher_high_api) @@ -98,6 +101,7 @@ set(mir_fusers fuse_sequence_reverse_embedding fuse_instance_norm_activation fuse_elementwise_add_scale + fuse_flatten_fc fuse_fc_prelu fuse_conv_scale CACHE INTERNAL "fusers") diff --git a/lite/core/mir/fusion/flatten_fc_fuse_pass.cc b/lite/core/mir/fusion/flatten_fc_fuse_pass.cc new file mode 100644 index 00000000000..f8bbd1fc8af --- /dev/null +++ b/lite/core/mir/fusion/flatten_fc_fuse_pass.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/flatten_fc_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/flatten_fc_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void FlattenFcFusePass::Apply(const std::unique_ptr& graph) { + fusion::FlattenFcFuser flatten_fuser(" "); + flatten_fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_flatten_fc_fuse_pass, + paddle::lite::mir::FlattenFcFusePass) + .BindTargets({TARGET(kOpenCL)}) + .BindKernel("fc"); diff --git a/lite/core/mir/fusion/flatten_fc_fuse_pass.h b/lite/core/mir/fusion/flatten_fc_fuse_pass.h new file mode 100644 index 00000000000..79e0e442d2e --- /dev/null +++ b/lite/core/mir/fusion/flatten_fc_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class FlattenFcFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/flatten_fc_fuser.cc b/lite/core/mir/fusion/flatten_fc_fuser.cc new file mode 100644 index 00000000000..ee8e3bd9e51 --- /dev/null +++ b/lite/core/mir/fusion/flatten_fc_fuser.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/flatten_fc_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void FlattenFcFuser::BuildPattern() { + // flatten_contiguous_range + PMNode* x = VarNode("x") + ->assert_is_op_input("flatten_contiguous_range", "X") + ->AsInput(); + PMNode* flatten_contiguous_range = + OpNode("flatten_contiguous_range", "flatten_contiguous_range") + ->AsIntermediate(); + PMNode* out = VarNode("output") + ->assert_is_op_output("flatten_contiguous_range", "Out") + ->AsIntermediate(); + PMNode* xshape = + VarNode("xshape") + ->assert_is_op_output("flatten_contiguous_range", "XShape") + ->AsIntermediate(); + + // fc + // PMNode* input = VarNode("input")->assert_is_op_input("fc", + // "Input")->AsIntermediate(); + PMNode* weights = + VarNode("weights")->assert_is_op_input("fc", "W")->AsInput(); + PMNode* bias = VarNode("bias")->assert_is_op_input("fc", "Bias")->AsInput(); + PMNode* fc = OpNode("fc", "fc")->AsIntermediate(); + PMNode* fc_out = + VarNode("fc_out")->assert_is_op_output("fc", "Out")->AsOutput(); + + // create topology. + std::vector fc_inputs{bias, weights, out}; + *x >> *flatten_contiguous_range >> *out; + *flatten_contiguous_range >> *xshape; + fc_inputs >> *fc >> *fc_out; +} + +void FlattenFcFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto fc_op = LiteOpRegistry::Global().Create("fc"); + auto fc_old = matched.at("fc")->stmt()->op(); + auto* scope = fc_old->scope(); + auto& valid_places = fc_old->valid_places(); + fc_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places); + + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(matched.at("weights"), new_op_node); + IR_NODE_LINK_TO(matched.at("bias"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("fc_out")); +} + +cpp::OpDesc FlattenFcFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc = *matched.at("fc")->stmt()->op_info(); + op_desc.SetInput("Input", {matched.at("x")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("fc_out")->arg()->name}); + int in_num_col_dim = 1; + op_desc.SetAttr("in_num_col_dims", in_num_col_dim); + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/flatten_fc_fuser.h b/lite/core/mir/fusion/flatten_fc_fuser.h new file mode 100644 index 00000000000..151ce6267dc --- /dev/null +++ b/lite/core/mir/fusion/flatten_fc_fuser.h @@ -0,0 +1,39 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class FlattenFcFuser : public FuseBase { + public: + explicit FlattenFcFuser(const std::string& type) {} + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 1a7ebbe1ef5..41619c32448 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -113,6 +113,7 @@ class Optimizer { "lite_scale_activation_fuse_pass", // "lite_elementwise_scale_fuse_pass", // "lite_instance_norm_activation_fuse_pass", // + "lite_flatten_fc_fuse_pass", // "lite_fc_prelu_fuse_pass", // "lite_elementwise_activation_fuse_pass", "lite_conv_scale_fuse_pass",