[Arm] Fix fuse_attention support old_quant_format #10027

Merged
2 changes: 0 additions & 2 deletions lite/backends/arm/math/conv_impl.cc
@@ -735,7 +735,6 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
n,
k,
flag_bias,
GemmMBias,
false,
scale_group,
act_param,
@@ -1605,7 +1604,6 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
n,
k,
flag_bias,
GemmMBias,
false,
scale_group,
act_param,
lite/core/optimizer/mir/fusion/transformer_attention_fuse_pass.cc
@@ -25,17 +25,28 @@ namespace mir {

void TransformerAttentionFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::TransformerAttentionFuser fuser;
bool has_int8 = false;
for (auto& place : graph->valid_places()) {
if (place.precision == PRECISION(kInt8)) {
has_int8 = true;
}
}
if ((has_int8)) {
fuser(graph.get());
} else {
return;
std::vector<bool> reshape_has_xshapes = {false, true};
std::vector<bool> transpose_has_xshapes = {false, true};
std::vector<bool> dropout_masks = {false, true};
std::vector<std::string> mul_types = {"matmul", "matmul_v2"};
for (auto reshape_has_xshape : reshape_has_xshapes) {
for (auto transpose_has_xshape : transpose_has_xshapes) {
for (auto dropout_mask : dropout_masks) {
for (auto mul_type : mul_types) {
fusion::TransformerAttentionFuser fuser(
reshape_has_xshape, transpose_has_xshape, dropout_mask, mul_type);
if ((has_int8)) {
fuser(graph.get());
}
}
}
}
}
}

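Note (not part of the diff): Apply now enumerates every combination of the four pattern dimensions (reshape XShape output, transpose XShape output, dropout Mask output, and matmul vs. matmul_v2), so 2 * 2 * 2 * 2 = 16 fuser variants are tried per graph. Below is a minimal sketch of the same enumeration with the has_int8 check hoisted, so that no fusers are constructed for non-int8 graphs; the helper name ApplyAllVariants is hypothetical, not part of the PR:

// Sketch only: equivalent variant enumeration with the int8 check hoisted.
void ApplyAllVariants(SSAGraph* graph, bool has_int8) {
  if (!has_int8) return;  // no int8 place, nothing to fuse
  for (bool reshape_has_xshape : {false, true}) {
    for (bool transpose_has_xshape : {false, true}) {
      for (bool dropout_mask : {false, true}) {
        for (const char* mul_type : {"matmul", "matmul_v2"}) {
          fusion::TransformerAttentionFuser fuser(
              reshape_has_xshape, transpose_has_xshape, dropout_mask, mul_type);
          fuser(graph);  // pattern match + rewrite for this variant
        }
      }
    }
  }
}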
168 changes: 112 additions & 56 deletions lite/core/optimizer/mir/fusion/transformer_attention_fuser.cc
@@ -38,16 +38,19 @@ namespace fusion {
 *          scale        |        |
 *             \         /        |
 *              \       /         |
 *            matmul_v2           /
 *                \               /
 *                 \             /
 *          elementwise_add     /
 *                 \           /
 *                  \         /
 *               softmax     /
 *                   \      /
 *                    \    /
 *                 matmul_v2
 *         matmul_v2/matmul        |
 *                \                /
 *                 \              /
 *          elementwise_add      /
 *                 \            /
 *                  \          /
 *               softmax      /
 *                  |        /
 *                  |       /
 *               dropout   /
 *                  \     /
 *                   \   /
 *           matmul_v2/matmul
 *                  |
 *                  |
 *                output
@@ -56,20 +59,33 @@
void TransformerAttentionFuser::BuildPattern() {
auto matmul0_attr_teller = [](const Node* node) -> bool {
auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
auto trans_x = op_desc.GetAttr<bool>("trans_x");
auto trans_y = op_desc.GetAttr<bool>("trans_y");
bool trans_x;
bool trans_y;
if (op_desc.Type() == "matmul") {
trans_x = op_desc.GetAttr<bool>("transpose_X");
trans_y = op_desc.GetAttr<bool>("transpose_Y");
} else {
trans_x = op_desc.GetAttr<bool>("trans_x");
trans_y = op_desc.GetAttr<bool>("trans_y");
}
auto res = (trans_x == false && trans_y == true);
return res;
};
auto matmul1_attr_teller = [](const Node* node) -> bool {
auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
auto trans_x = op_desc.GetAttr<bool>("trans_x");
auto trans_y = op_desc.GetAttr<bool>("trans_y");
bool trans_x;
bool trans_y;
if (op_desc.Type() == "matmul") {
trans_x = op_desc.GetAttr<bool>("transpose_X");
trans_y = op_desc.GetAttr<bool>("transpose_Y");
} else {
trans_x = op_desc.GetAttr<bool>("trans_x");
trans_y = op_desc.GetAttr<bool>("trans_y");
}
auto res = (trans_x == false && trans_y == false);
return res;
};
auto* input0 =
VarNode("input0")->assert_is_op_input("fc", "Input")->AsInput();
auto* input = VarNode("input")->assert_is_op_input("fc", "Input")->AsInput();
// fc
auto* fc0_w = VarNode("fc0_w")->assert_is_op_input("fc", "W");
auto* fc0_bias = VarNode("fc0_bias")->assert_is_op_input("fc", "Bias");
@@ -99,44 +115,54 @@ void TransformerAttentionFuser::BuildPattern() {
auto* reshape2_out =
VarNode("reshape2_out")->assert_is_op_output("reshape2", "Out");

auto* xshape0 = VarNode("xshape0")->assert_is_op_output("reshape2", "XShape");
auto* xshape1 = VarNode("xshape1")->assert_is_op_output("reshape2", "XShape");
auto* xshape2 = VarNode("xshape2")->assert_is_op_output("reshape2", "XShape");
PMNode* xshape0 = nullptr;
PMNode* xshape1 = nullptr;
PMNode* xshape2 = nullptr;
if (reshape_has_xshape_) {
xshape0 = VarNode("xshape0")->assert_is_op_output("reshape2", "XShape");
xshape1 = VarNode("xshape1")->assert_is_op_output("reshape2", "XShape");
xshape2 = VarNode("xshape2")->assert_is_op_output("reshape2", "XShape");
}

// transpose2
auto* transpose0 = OpNode("transpose0", "transpose2")
->assert_op_attr("axis", std::vector<int>{0, 2, 1, 3});
auto* transpose0_out =
VarNode("transpose0_out")->assert_is_op_output("transpose2", "Out");
auto* xshape3 =
VarNode("xshape3")->assert_is_op_output("transpose2", "XShape");

auto* transpose1 = OpNode("transpose1", "transpose2")
->assert_op_attr("axis", std::vector<int>{0, 2, 1, 3});
auto* transpose1_out =
VarNode("transpose1_out")->assert_is_op_output("transpose2", "Out");
auto* xshape4 =
VarNode("xshape4")->assert_is_op_output("transpose2", "XShape");

auto* transpose2 = OpNode("transpose2", "transpose2")
->assert_op_attr("axis", std::vector<int>{0, 2, 1, 3});
auto* transpose2_out =
VarNode("transpose2_out")->assert_is_op_output("transpose2", "Out");
auto* xshape5 =
VarNode("xshape5")->assert_is_op_output("transpose2", "XShape");

PMNode* xshape3 = nullptr;
PMNode* xshape4 = nullptr;
PMNode* xshape5 = nullptr;
if (transpose_has_xshape_) {
xshape3 = VarNode("xshape3")->assert_is_op_output("transpose2", "XShape");
xshape4 = VarNode("xshape4")->assert_is_op_output("transpose2", "XShape");
xshape5 = VarNode("xshape5")->assert_is_op_output("transpose2", "XShape");
}

// scale
auto* scale0 = OpNode("scale0", "scale");
auto* scale0_out = VarNode("scale0_out")->assert_is_op_output("scale", "Out");

// matmul_v2
auto* matmul0 = OpNode("matmul0", "matmul_v2")
->assert_node_satisfied(matmul0_attr_teller);
// matmul
auto* matmul0 =
OpNode("matmul0", mul_type_)->assert_node_satisfied(matmul0_attr_teller);
auto* matmul0_out =
VarNode("matmul0_out")->assert_is_op_output("matmul_v2", "Out");
VarNode("matmul0_out")->assert_is_op_output(mul_type_, "Out");

// elementwise_add
auto* input1 =
VarNode("input1")->assert_is_op_input("elementwise_add", "Y")->AsInput();
auto* residual = VarNode("residual")
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto* add = OpNode("add", "elementwise_add");
auto* add0_out =
VarNode("add0_out")->assert_is_op_output("elementwise_add", "Out");
@@ -146,42 +172,67 @@ void TransformerAttentionFuser::BuildPattern() {
auto* softmax0_out =
VarNode("softmax0_out")->assert_is_op_output("softmax", "Out");

// matmul_v2
auto* matmul1 = OpNode("matmul1", "matmul_v2")
->assert_node_satisfied(matmul1_attr_teller);
// dropout
auto* dropout = OpNode("dropout", "dropout");
auto* dropout_out =
VarNode("dropout_out")->assert_is_op_output("dropout", "Out");
PMNode* mask_out = nullptr;
if (dropout_mask_) {
mask_out = VarNode("mask_out")->assert_is_op_output("dropout", "Mask");
}

// matmul
auto* matmul1 =
OpNode("matmul1", mul_type_)->assert_node_satisfied(matmul1_attr_teller);

auto* Out = VarNode("Out");

std::vector<PMNode*> fc0_inputs{input0, fc0_w, fc0_bias};
std::vector<PMNode*> fc1_inputs{input0, fc1_w, fc1_bias};
std::vector<PMNode*> fc2_inputs{input0, fc2_w, fc2_bias};
std::vector<PMNode*> fc0_inputs{input, fc0_w, fc0_bias};
std::vector<PMNode*> fc1_inputs{input, fc1_w, fc1_bias};
std::vector<PMNode*> fc2_inputs{input, fc2_w, fc2_bias};
fc0_inputs >> *fc0 >> *fc0_out >> *reshape0 >> *reshape0_out >> *transpose0 >>
*transpose0_out >> *scale0 >> *scale0_out;
fc1_inputs >> *fc1 >> *fc1_out >> *reshape1 >> *reshape1_out >> *transpose1 >>
*transpose1_out;
fc2_inputs >> *fc2 >> *fc2_out >> *reshape2 >> *reshape2_out >> *transpose2 >>
*transpose2_out;
*reshape0 >> *xshape0;
*reshape1 >> *xshape1;
*reshape2 >> *xshape2;
*transpose0 >> *xshape3;
*transpose1 >> *xshape4;
*transpose2 >> *xshape5;
if (reshape_has_xshape_) {
*reshape0 >> *xshape0;
*reshape1 >> *xshape1;
*reshape2 >> *xshape2;
}
if (transpose_has_xshape_) {
*transpose0 >> *xshape3;
*transpose1 >> *xshape4;
*transpose2 >> *xshape5;
}

std::vector<PMNode*> matmul0_inputs{scale0_out, transpose1_out};
matmul0_inputs >> *matmul0 >> *matmul0_out;
std::vector<PMNode*> add0_inputs{matmul0_out, input1};
add0_inputs >> *add >> *add0_out >> *softmax0 >> *softmax0_out;
std::vector<PMNode*> add0_inputs{matmul0_out, residual};
add0_inputs >> *add >> *add0_out >> *softmax0 >> *softmax0_out >> *dropout >>
*dropout_out;

if (dropout_mask_) {
*dropout >> *mask_out;
}

std::vector<PMNode*> matmul1_inputs{softmax0_out, transpose2_out};
std::vector<PMNode*> matmul1_inputs{dropout_out, transpose2_out};
matmul1_inputs >> *matmul1 >> *Out;

xshape0->AsIntermediate();
xshape1->AsIntermediate();
xshape2->AsIntermediate();
xshape3->AsIntermediate();
xshape4->AsIntermediate();
xshape5->AsIntermediate();
if (reshape_has_xshape_) {
xshape0->AsIntermediate();
xshape1->AsIntermediate();
xshape2->AsIntermediate();
}
if (transpose_has_xshape_) {
xshape3->AsIntermediate();
xshape4->AsIntermediate();
xshape5->AsIntermediate();
}
if (dropout_mask_) {
mask_out->AsIntermediate();
}
fc0->AsIntermediate();
fc0_out->AsIntermediate();
reshape0->AsIntermediate();
@@ -208,6 +259,8 @@ void TransformerAttentionFuser::BuildPattern() {
add0_out->AsIntermediate();
softmax0->AsIntermediate();
softmax0_out->AsIntermediate();
dropout->AsIntermediate();
dropout_out->AsIntermediate();
matmul1->AsIntermediate();
}

@@ -312,8 +365,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph,
auto* scope = fc->scope();

// set input
op_desc.SetInput("Input0", {matched.at("input0")->arg()->name});
op_desc.SetInput("Input1", {matched.at("input1")->arg()->name});
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Residual", {matched.at("residual")->arg()->name});

// fc
auto fc0_op_desc = matched.at("fc0")->stmt()->op_info();
@@ -393,6 +446,7 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph,
scale0_scale,
bias0_dims[0]);
op_desc.SetAttr<std::vector<float>>("fc0_scale", fuse_scales);
op_desc.SetAttr<std::vector<float>>("Input0_scale", fc0_scale_x);
// fc 1
auto matmul0_scale_x =
matmul0_op_desc->GetAttr<std::vector<float>>("X0_scale");
@@ -424,6 +478,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph,
op_desc.SetInput("Bias", {matched.at("fc0_bias")->arg()->name});
op_desc.SetAttr<std::string>("op_type",
fc0_op_desc->GetAttr<std::string>("op_type"));
op_desc.SetAttr<int32_t>("in_num_col_dims",
fc0_op_desc->GetAttr<int32_t>("in_num_col_dims"));
// reshape
auto reshape_op_desc = matched.at("reshape0")->stmt()->op_info();
op_desc.SetAttr<std::vector<int>>(
@@ -440,8 +496,8 @@ void TransformerAttentionFuser::InsertNewNode(SSAGraph* graph,
fused_attention_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(fused_attention_op, valid_places);
DirectedLink(matched.at("input0"), new_op_node);
DirectedLink(matched.at("input1"), new_op_node);
DirectedLink(matched.at("input"), new_op_node);
DirectedLink(matched.at("residual"), new_op_node);
DirectedLink(matched.at("fc0_w"), new_op_node);
DirectedLink(matched.at("fc0_bias"), new_op_node);
DirectedLink(new_op_node, matched.at("Out"));
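Note (not part of the diff): the two tellers above differ only in the transpose-flag values they accept; the attribute-name branch exists because the legacy matmul op stores its flags as transpose_X / transpose_Y while matmul_v2 uses trans_x / trans_y. A condensed sketch of the shared lookup follows; the helper GetTransposeFlags is hypothetical, and the OpInfo calls mirror the ones used in the diff:

#include <string>
#include <utility>

// Sketch only: unified transpose-flag lookup for matmul / matmul_v2.
static std::pair<bool, bool> GetTransposeFlags(const OpInfo& op_desc) {
  if (op_desc.Type() == "matmul") {
    // Legacy matmul capitalizes the attribute suffixes.
    return {op_desc.GetAttr<bool>("transpose_X"),
            op_desc.GetAttr<bool>("transpose_Y")};
  }
  // matmul_v2 spells the same flags in lowercase.
  return {op_desc.GetAttr<bool>("trans_x"), op_desc.GetAttr<bool>("trans_y")};
}

With this helper, matmul0 accepts only {false, true} (consistent with the Q * K^T product, where K is transposed) and matmul1 only {false, false} (the attention-weights * V product).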
12 changes: 12 additions & 0 deletions lite/core/optimizer/mir/fusion/transformer_attention_fuser.h
@@ -25,10 +25,22 @@ namespace fusion {

class TransformerAttentionFuser : public FuseBase {
public:
explicit TransformerAttentionFuser(bool reshape_has_xshape,
bool transpose_has_xshape,
bool dropout_mask,
std::string mul_type)
: reshape_has_xshape_(reshape_has_xshape),
transpose_has_xshape_(transpose_has_xshape),
dropout_mask_(dropout_mask),
mul_type_(mul_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

private:
bool reshape_has_xshape_;
bool transpose_has_xshape_;
bool dropout_mask_;
std::string mul_type_;
};

} // namespace fusion
2 changes: 1 addition & 1 deletion lite/core/optimizer/optimizer.cc
@@ -187,11 +187,11 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
"lite_elementwise_activation_fuse_pass",
"lite_conv_scale_fuse_pass",
"lite_conv_elementwise_tree_fuse_pass",
"transformer_attention_fuse_pass",
"lite_greater_than_cast_fuse_pass",
"identity_dropout_eliminate_pass",
"sparse_conv_detect_pass",
// "keepdims_convert_pass",
"transformer_attention_fuse_pass",
"__xpu__max_pooling_pad_zero_detect_fuse_pass",
"__xpu__graph_dedup_pass",
"__xpu__resnet_fuse_pass",
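Note (not part of the diff): the reordering in optimizer.cc is likely what makes the new pattern reachable. The fused subgraph now contains a dropout op, and identity_dropout_eliminate_pass removes inference-time dropout ops, so transformer_attention_fuse_pass has to run before it; moving the pass up the list preserves the dropout node the pattern expects to match.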