[Cherry-pick][Pass] support transpose_Y in matmul_elt_add_fuse_pass #7178

Merged
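For orientation: this pass fuses a matmul + elementwise_add pair (optionally followed by relu) into a single fc op. This PR extends the fusion to matmul ops with transpose_Y == true, provided Y is a persistable weight, by transposing the weight once at optimization time. Below is a minimal standalone sketch of the algebra the rewrite relies on (illustrative values only, not code from this PR):

```cpp
#include <cassert>
#include <vector>

int main() {
  // X: 1x2 activation, Y: 3x2 weight stored transposed (transpose_Y == true),
  // b: bias of length 3. Values are illustrative.
  std::vector<float> X = {1.f, 2.f};
  std::vector<float> Y = {1.f, 4.f, 2.f, 5.f, 3.f, 6.f};  // 3x2, row-major
  std::vector<float> b = {0.5f, 0.5f, 0.5f};

  // What the unfused graph computes: matmul(X, Y^T) + b.
  std::vector<float> ref(3);
  for (int n = 0; n < 3; ++n) {
    ref[n] = X[0] * Y[n * 2 + 0] + X[1] * Y[n * 2 + 1] + b[n];
  }

  // What the fused graph computes: transpose Y once offline into W (2x3),
  // then fc(X, W, b).
  std::vector<float> W(6);
  for (int r = 0; r < 3; ++r) {
    for (int c = 0; c < 2; ++c) W[c * 3 + r] = Y[r * 2 + c];
  }
  std::vector<float> out(3);
  for (int n = 0; n < 3; ++n) {
    out[n] = X[0] * W[0 * 3 + n] + X[1] * W[1 * 3 + n] + b[n];
  }

  for (int n = 0; n < 3; ++n) assert(out[n] == ref[n]);
  return 0;
}
```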
@@ -31,17 +31,17 @@ void MatmulElementwiseAddFusePass::Apply(
   }
 #if defined(LITE_WITH_X86) || defined(LITE_WITH_CUDA) || defined(LITE_WITH_ARM)
 #ifdef LITE_WITH_MLU
-  fusion::MatmulElementwiseAddFuser fuser(false);
+  fusion::MatmulElementwiseAddFuser fuser(false, graph);
   fuser(graph.get());
 #else
-  fusion::MatmulElementwiseAddFuser fuser(true);
+  fusion::MatmulElementwiseAddFuser fuser(true, graph);
   fuser(graph.get());
 #endif
 #endif
-  fusion::MatmulElementwiseAddFuser fuser2(false);
+  fusion::MatmulElementwiseAddFuser fuser2(false, graph);
   fuser2(graph.get());
 #ifdef LITE_WITH_FPGA
-  fusion::MatmulElementwiseAddFuser fpga_fuser(true);
+  fusion::MatmulElementwiseAddFuser fpga_fuser(true, graph);
   fpga_fuser(graph.get());
 #endif
 }
56 changes: 47 additions & 9 deletions lite/core/optimizer/mir/fusion/matmul_elementwise_add_fuser.cc
@@ -22,20 +22,13 @@ namespace lite {
 namespace mir {
 namespace fusion {
 
-void MatmulElementwiseAddFuser::BuildPattern() {
+void MatmulElementwiseAddFuser::CreatePattern() {
   // create nodes.
   auto* x = VarNode("x")->assert_is_op_input("matmul", "X");
   auto* W = VarNode("W")->assert_is_persistable_var()->assert_is_op_input(
       "matmul", "Y");
   auto* b = VarNode("b")->assert_is_persistable_var();
-  /*
-   * The mul op must satisfy the following conditions:
-   * 1. the transpose_X and transpose_Y attrs are false
-   * 2. the alpha attr is 1.0
-   */
   auto* matmul = OpNode("matmul", "matmul")
-                     ->assert_op_attr<bool>("transpose_X", false)
-                     ->assert_op_attr<bool>("transpose_Y", false)
                      ->assert_op_attr_satisfied<float>("alpha", [](float attr) {
                        return (std::fabs(attr - 1.0) < 1e-5);
                      });
@@ -66,6 +59,27 @@ void MatmulElementwiseAddFuser::BuildPattern() {
   }
 }
 
+void MatmulElementwiseAddFuser::BuildPattern() {
+  for (auto& node : graph_->StmtTopologicalOrder()) {
+    if (node->IsStmt() &&
+        node->AsStmt().picked_kernel().op_type() == "matmul") {
+      auto* scope = node->stmt()->op()->scope();
+      auto op_desc = node->stmt()->mutable_op_info();
+
+      bool transpose_x = op_desc->GetAttr<bool>("transpose_X");
+      bool transpose_y = op_desc->GetAttr<bool>("transpose_Y");
+      auto arg_y_name = op_desc->Input("Y").front();
+      auto& tensor_y = scope->FindVar(arg_y_name)->Get<lite::Tensor>();
+      bool is_persist = tensor_y.persistable();
+      if ((!transpose_x && !transpose_y) ||
+          (!transpose_x && transpose_y && is_persist)) {
+        CreatePattern();
+        return;
+      }
+    }
+  }
+}
+
 void MatmulElementwiseAddFuser::InsertNewNode(SSAGraph* graph,
                                               const key2nodes_t& matched) {
   auto op_desc = GenOpDesc(matched);
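The constructor now receives the graph so that BuildPattern can pre-scan the matmul ops before committing to a pattern: the pattern is only registered when some matmul satisfies the gating condition. A hypothetical standalone restatement of that condition (names are illustrative, not from the PR):

```cpp
#include <cstdio>

// Illustrative predicate mirroring the condition in the new BuildPattern.
bool CanFuse(bool transpose_x, bool transpose_y, bool y_is_persistable) {
  return (!transpose_x && !transpose_y) ||
         (!transpose_x && transpose_y && y_is_persistable);
}

int main() {
  std::printf("%d\n", CanFuse(false, true, true));   // 1: new case this PR enables
  std::printf("%d\n", CanFuse(false, true, false));  // 0: runtime Y, cannot pre-transpose
  std::printf("%d\n", CanFuse(true, false, true));   // 0: transpose_X unsupported
  return 0;
}
```

The persistable requirement matters because only a weight that exists in the scope at optimization time can be transposed once and reused; a runtime-produced Y would need a transpose at every inference.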
@@ -83,6 +97,16 @@ void MatmulElementwiseAddFuser::InsertNewNode(SSAGraph* graph,
   IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
 }
 
+template <typename T>
+void transpose(T* dst, const T* src, const int src_rows, const int src_cols) {
+  CHECK(src && dst && src_rows > 0 && src_cols > 0);
+  for (int r = 0; r < src_rows; ++r) {
+    for (int c = 0; c < src_cols; ++c) {
+      dst[c * src_rows + r] = src[r * src_cols + c];
+    }
+  }
+}
+
 cpp::OpDesc MatmulElementwiseAddFuser::GenOpDesc(const key2nodes_t& matched) {
   auto op_desc = *matched.at("matmul")->stmt()->op_info();
 
@@ -95,7 +119,7 @@ cpp::OpDesc MatmulElementwiseAddFuser::GenOpDesc(const key2nodes_t& matched) {
                          op_desc.HasInputScale(input_y_name);
   if (is_quantized_op) {
     x_scale_vct = op_desc.GetInputScale(input_x_name);
-    y_scale_vct = op_desc.GetInputScale(op_desc.Input("Y").front());
+    y_scale_vct = op_desc.GetInputScale(input_y_name);
   }
   auto* scope = matched.at("matmul")->stmt()->op()->scope();
   auto x_shape = scope->FindVar(input_x_name)->Get<lite::Tensor>().dims();
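A quick standalone exercise of the transpose helper introduced above, which the weight fold in the next hunk uses (CHECK is swapped for assert here so the snippet compiles without the Lite headers; values are illustrative):

```cpp
#include <cassert>
#include <vector>

template <typename T>
void transpose(T* dst, const T* src, const int src_rows, const int src_cols) {
  assert(src && dst && src_rows > 0 && src_cols > 0);
  for (int r = 0; r < src_rows; ++r) {
    for (int c = 0; c < src_cols; ++c) {
      dst[c * src_rows + r] = src[r * src_cols + c];
    }
  }
}

int main() {
  // 2x3 row-major source: [[1, 2, 3], [4, 5, 6]].
  std::vector<float> src = {1, 2, 3, 4, 5, 6};
  std::vector<float> dst(6);
  transpose(dst.data(), src.data(), 2, 3);
  // 3x2 result: [[1, 4], [2, 5], [3, 6]].
  assert(dst[0] == 1 && dst[1] == 4 && dst[4] == 3 && dst[5] == 6);
  return 0;
}
```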
@@ -105,6 +129,20 @@ cpp::OpDesc MatmulElementwiseAddFuser::GenOpDesc(const key2nodes_t& matched) {
           << scope->FindVar(input_y_name)->Get<lite::Tensor>().dims();
   VLOG(4) << "x_num_col_dims: " << x_num_col_dims;
 
+  bool transpose_y = op_desc.GetAttr<bool>("transpose_Y");
+  if (transpose_y) {
+    auto* y_t = scope->FindVar(input_y_name)->GetMutable<lite::Tensor>();
+    auto y_dims = y_t->dims();
+    Tensor y_t_tmp;
+    y_t_tmp.CopyDataFrom(*y_t);  // in order to copy y_t's
+                                 // target_, lod_, precision_, etc. to y_t_tmp
+    y_t_tmp.Resize({y_dims[1], y_dims[0]});
+    const float* src = y_t->data<float>();
+    float* dst = y_t_tmp.mutable_data<float>();
+    transpose<float>(dst, src, y_dims[0], y_dims[1]);
+    y_t->CopyDataFrom(y_t_tmp);
+  }
+
   op_desc.mutable_inputs()->clear();
   op_desc.mutable_outputs()->clear();
   op_desc.SetType("fc");
@@ -25,13 +25,17 @@ namespace fusion {
 
 class MatmulElementwiseAddFuser : public FuseBase {
  public:
-  explicit MatmulElementwiseAddFuser(bool with_relu) : with_relu_(with_relu) {}
+  explicit MatmulElementwiseAddFuser(bool with_relu,
+                                     const std::unique_ptr<SSAGraph>& graph)
+      : with_relu_(with_relu), graph_(graph) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  void CreatePattern();
   bool with_relu_;
+  const std::unique_ptr<SSAGraph>& graph_;
 };
 
 }  // namespace fusion