From 4bead3f8e1c9527863fbe3b698a4ff5d9a653bfb Mon Sep 17 00:00:00 2001
From: zhaoyang-star
Date: Mon, 11 Oct 2021 18:01:37 +0800
Subject: [PATCH 1/2] [Pass] support transpose_Y in matmul_elt_add_fuse_pass
 (#7162)

---
 .../matmul_elementwise_add_fuse_pass.cc       |  8 +--
 .../fusion/matmul_elementwise_add_fuser.cc    | 56 ++++++++++++++++---
 .../mir/fusion/matmul_elementwise_add_fuser.h |  6 +-
 3 files changed, 56 insertions(+), 14 deletions(-)

diff --git a/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuse_pass.cc b/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuse_pass.cc
index 7847ae93557..aea54b711b6 100644
--- a/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuse_pass.cc
+++ b/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuse_pass.cc
@@ -31,17 +31,17 @@ void MatmulElementwiseAddFusePass::Apply(
   }
 #if defined(LITE_WITH_X86) || defined(LITE_WITH_CUDA) || defined(LITE_WITH_ARM)
 #ifdef LITE_WITH_MLU
-  fusion::MatmulElementwiseAddFuser fuser(false);
+  fusion::MatmulElementwiseAddFuser fuser(false, graph);
   fuser(graph.get());
 #else
-  fusion::MatmulElementwiseAddFuser fuser(true);
+  fusion::MatmulElementwiseAddFuser fuser(true, graph);
   fuser(graph.get());
 #endif
 #endif
-  fusion::MatmulElementwiseAddFuser fuser2(false);
+  fusion::MatmulElementwiseAddFuser fuser2(false, graph);
   fuser2(graph.get());
 #ifdef LITE_WITH_FPGA
-  fusion::MatmulElementwiseAddFuser fpga_fuser(true);
+  fusion::MatmulElementwiseAddFuser fpga_fuser(true, graph);
   fpga_fuser(graph.get());
 #endif
 }
diff --git a/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.cc b/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.cc
index 1788f5f5c1e..41a1e76bff2 100644
--- a/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.cc
+++ b/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.cc
@@ -22,20 +22,13 @@ namespace lite {
 namespace mir {
 namespace fusion {
 
-void MatmulElementwiseAddFuser::BuildPattern() {
+void MatmulElementwiseAddFuser::CreatePattern() {
   // create nodes.
   auto* x = VarNode("x")->assert_is_op_input("matmul", "X");
   auto* W = VarNode("W")->assert_is_persistable_var()->assert_is_op_input(
       "matmul", "Y");
   auto* b = VarNode("b")->assert_is_persistable_var();
-  /*
-   * The mul op must satisfy the following conditions:
-   * 1. the transpose_X and transpose_Y attrs are false
-   * 2. the alpha attr is 1.0
-   */
   auto* matmul = OpNode("matmul", "matmul")
-                     ->assert_op_attr<bool>("transpose_X", false)
-                     ->assert_op_attr<bool>("transpose_Y", false)
                      ->assert_op_attr_satisfied<float>("alpha", [](float attr) {
                        return (std::fabs(attr - 1.0) < 1e-5);
                      });
@@ -66,6 +59,27 @@ void MatmulElementwiseAddFuser::BuildPattern() {
   }
 }
 
+void MatmulElementwiseAddFuser::BuildPattern() {
+  for (auto& node : graph_->StmtTopologicalOrder()) {
+    if (node->IsStmt() &&
+        node->AsStmt().picked_kernel().op_type() == "matmul") {
+      auto* scope = node->stmt()->op()->scope();
+      auto op_desc = node->stmt()->mutable_op_info();
+
+      bool transpose_x = op_desc->GetAttr<bool>("transpose_X");
+      bool transpose_y = op_desc->GetAttr<bool>("transpose_Y");
+      auto arg_y_name = op_desc->Input("Y").front();
+      auto& tensor_y = scope->FindVar(arg_y_name)->Get<lite::Tensor>();
+      bool is_persist = tensor_y.persistable();
+      if ((!transpose_x && !transpose_y) ||
+          (!transpose_x && transpose_y && is_persist)) {
+        CreatePattern();
+        return;
+      }
+    }
+  }
+}
+
 void MatmulElementwiseAddFuser::InsertNewNode(SSAGraph* graph,
                                               const key2nodes_t& matched) {
   auto op_desc = GenOpDesc(matched);
@@ -83,6 +97,16 @@ void MatmulElementwiseAddFuser::InsertNewNode(SSAGraph* graph,
   IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
 }
 
+template <typename T>
+void transpose(T* dst, const T* src, const int src_rows, const int src_cols) {
+  CHECK(src && dst && src_rows > 0 && src_cols > 0);
+  for (int r = 0; r < src_rows; ++r) {
+    for (int c = 0; c < src_cols; ++c) {
+      dst[c * src_rows + r] = src[r * src_cols + c];
+    }
+  }
+}
+
 cpp::OpDesc MatmulElementwiseAddFuser::GenOpDesc(const key2nodes_t& matched) {
   auto op_desc = *matched.at("matmul")->stmt()->op_info();
 
@@ -95,7 +119,7 @@ cpp::OpDesc MatmulElementwiseAddFuser::GenOpDesc(const key2nodes_t& matched) {
                          op_desc.HasInputScale(input_y_name);
   if (is_quantized_op) {
     x_scale_vct = op_desc.GetInputScale(input_x_name);
-    y_scale_vct = op_desc.GetInputScale(op_desc.Input("Y").front());
+    y_scale_vct = op_desc.GetInputScale(input_y_name);
   }
   auto* scope = matched.at("matmul")->stmt()->op()->scope();
   auto x_shape = scope->FindVar(input_x_name)->Get<lite::Tensor>().dims();
@@ -105,6 +129,20 @@ cpp::OpDesc MatmulElementwiseAddFuser::GenOpDesc(const key2nodes_t& matched) {
           << scope->FindVar(input_y_name)->Get<lite::Tensor>().dims();
   VLOG(4) << "x_num_col_dims: " << x_num_col_dims;
 
+  bool transpose_y = op_desc.GetAttr<bool>("transpose_Y");
+  if (transpose_y) {
+    auto* y_t = scope->FindVar(input_y_name)->GetMutable<lite::Tensor>();
+    auto y_dims = y_t->dims();
+    Tensor y_t_tmp;
+    y_t_tmp.CopyDataFrom(*y_t);  // in order to copy y_t's
+                                 // target_, lod_, precision_, etc. to y_t_tmp
+    y_t_tmp.Resize({y_dims[1], y_dims[0]});
+    const float* src = y_t->data<float>();
+    float* dst = y_t_tmp.mutable_data<float>();
+    transpose(dst, src, y_dims[0], y_dims[1]);
+    y_t->CopyDataFrom(y_t_tmp);
+  }
+
   op_desc.mutable_inputs()->clear();
   op_desc.mutable_outputs()->clear();
   op_desc.SetType("fc");
diff --git a/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.h b/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.h
index c4ed5935326..5409dfe6211 100644
--- a/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.h
+++ b/lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.h
@@ -25,13 +25,17 @@ namespace fusion {
 
 class MatmulElementwiseAddFuser : public FuseBase {
  public:
-  explicit MatmulElementwiseAddFuser(bool with_relu) : with_relu_(with_relu) {}
+  explicit MatmulElementwiseAddFuser(bool with_relu,
+                                     const std::unique_ptr<SSAGraph>& graph)
+      : with_relu_(with_relu), graph_(graph) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  void CreatePattern();
 
   bool with_relu_;
+  const std::unique_ptr<SSAGraph>& graph_;
 };
 
 }  // namespace fusion

From ac98fce5449b8236571bba653007210684d80e7c Mon Sep 17 00:00:00 2001
From: zhaoyang-star
Date: Mon, 11 Oct 2021 18:04:32 +0800
Subject: [PATCH 2/2] test=develop
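
Note on the weight pre-transpose (illustrative sketch, not part of the patch series above): when the matched matmul has transpose_Y == true, GenOpDesc rewrites the persistable Y tensor in place from [rows, cols] to [cols, rows] so the resulting fc op can consume it directly. The standalone C++ example below reproduces the same index mapping as the transpose() helper added in matmul_elementwise_add_fuser.cc; the 2x3 matrix values are made up purely for illustration.

    #include <cassert>
    #include <vector>

    // Row-major transpose: dst gets shape [src_cols, src_rows].
    // Same element mapping as the helper introduced by the patch:
    // dst(c, r) = src(r, c).
    template <typename T>
    void transpose(T* dst, const T* src, int src_rows, int src_cols) {
      for (int r = 0; r < src_rows; ++r) {
        for (int c = 0; c < src_cols; ++c) {
          dst[c * src_rows + r] = src[r * src_cols + c];
        }
      }
    }

    int main() {
      // 2x3 input (made-up values):
      //   1 2 3
      //   4 5 6
      std::vector<float> y = {1, 2, 3, 4, 5, 6};
      std::vector<float> y_t(y.size());
      transpose(y_t.data(), y.data(), 2, 3);
      // Expected 3x2 result:
      //   1 4
      //   2 5
      //   3 6
      std::vector<float> expected = {1, 4, 2, 5, 3, 6};
      assert(y_t == expected);
      return 0;
    }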