[xpu] multi_encoder supports no mask input, such as VIT
linwei210 committed Nov 17, 2022
1 parent c5d5851 commit f967627
Showing 1 changed file with 70 additions and 40 deletions.
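Before this change, the single-encoder pattern hard-required an elementwise_add of the attention mask between the QK matmul and the softmax, so mask-free encoders such as ViT never matched. The new with_mask switch makes that part of the pattern optional. A minimal standalone sketch of the two attention-score paths the matcher now covers — plain standard-library C++, with names like AttentionScores that are illustrative only and not from this repository:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// The two attention-score chains the fuse pass must recognize:
//   with mask (e.g. BERT): qk_matmul -> elementwise_add(mask) -> softmax
//   no mask   (e.g. ViT) : qk_matmul -> softmax
// Passing mask_row == nullptr models the ViT-style graph.
std::vector<float> AttentionScores(const std::vector<float>& qk_row,
                                   const std::vector<float>* mask_row) {
  std::vector<float> s(qk_row);
  if (mask_row) {  // with_mask == true: add the mask before softmax
    for (size_t i = 0; i < s.size(); ++i) s[i] += (*mask_row)[i];
  }
  // Numerically stable softmax over the row.
  const float m = *std::max_element(s.begin(), s.end());
  float sum = 0.f;
  for (float& v : s) { v = std::exp(v - m); sum += v; }
  for (float& v : s) v /= sum;
  return s;
}

int main() {
  const std::vector<float> qk = {0.5f, 1.0f, -0.5f};
  const std::vector<float> mask = {0.f, 0.f, -1e4f};  // mask out a padded token
  const auto vit = AttentionScores(qk, nullptr);   // no Mask input
  const auto bert = AttentionScores(qk, &mask);    // Mask input present
  std::printf("no mask: %.3f %.3f %.3f\n", vit[0], vit[1], vit[2]);
  std::printf("masked : %.3f %.3f %.3f\n", bert[0], bert[1], bert[2]);
}
```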
110 changes: 70 additions & 40 deletions lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -48,18 +48,22 @@ class XPUSingleEncoderFuser : public FuseBase {
const std::string& input_pos = "Y",
const std::string& qkv_ln_2_out_pos = "Y",
const std::string& matmul_type = "matmul",
+ const std::string& matmul2_type = "matmul_v2",
const std::string& mul_type = "mul",
bool with_q_scale = true,
bool norm_before = false,
- const std::string& relative_type = "")
+ const std::string& relative_type = "",
+ bool with_mask = true)
: act_type_(act_type),
input_pos_(input_pos),
qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
matmul_type_(matmul_type),
+ matmul2_type_(matmul2_type),
mul_type_(mul_type),
with_q_scale_(with_q_scale),
norm_before_(norm_before),
- relative_emb_type_(relative_type) {}
+ relative_emb_type_(relative_type),
+ with_mask_(with_mask) {}

void BuildPattern() override {
auto* input = VarNode("input")
@@ -213,18 +217,25 @@ class XPUSingleEncoderFuser : public FuseBase {
->AsIntermediate();

auto* qk_matmul = OpNode("qk_matmul", matmul_type_)->AsIntermediate();
+ std::string op_after_qk_matmul = with_mask_ ? "elementwise_add" : "softmax";
auto* qk_matmul_out = VarNode("qk_matmul_out")
->assert_is_op_output(matmul_type_, "Out")
->assert_is_op_input("elementwise_add", "X")
->assert_is_op_input(op_after_qk_matmul, "X")
->AsIntermediate();
- auto* qk_mask = VarNode("qk_mask")
- ->assert_is_op_input("elementwise_add", "Y")
- ->AsInput();
- auto* qk_add = OpNode("qk_add", "elementwise_add")->AsIntermediate();
- auto* qk_add_out = VarNode("qk_add_out")
- ->assert_is_op_output("elementwise_add", "Out")
- ->assert_is_op_input("softmax", "X")
- ->AsIntermediate();
+ PMNode* qk_mask = nullptr;
+ PMNode* qk_add = nullptr;
+ PMNode* qk_add_out = nullptr;
+ if (with_mask_) {
+ qk_mask = VarNode("qk_mask")
+ ->assert_is_op_input("elementwise_add", "Y")
+ ->AsInput();
+ qk_add = OpNode("qk_add", "elementwise_add")->AsIntermediate();
+ qk_add_out = VarNode("qk_add_out")
+ ->assert_is_op_output("elementwise_add", "Out")
+ ->assert_is_op_input("softmax", "X")
+ ->AsIntermediate();
+ }

auto* qk_softmax = OpNode("qk_softmax", "softmax")->AsIntermediate();
auto* qk_softmax_out = VarNode("qk_softmax_out")
->assert_is_op_output("softmax", "Out")
@@ -256,16 +267,16 @@ class XPUSingleEncoderFuser : public FuseBase {
auto* v_transpose2 = OpNode("v_transpose2", "transpose2")->AsIntermediate();
auto* v_transpose2_out = VarNode("v_transpose2_out")
->assert_is_op_output("transpose2", "Out")
- ->assert_is_op_input(matmul_type_, "Y")
+ ->assert_is_op_input(matmul2_type_, "Y")
->AsIntermediate();
auto* v_transpose2_xshape =
VarNode("v_transpose2_xshape")
->assert_is_op_output("transpose2", "XShape")
->AsIntermediate();

- auto* qkv_matmul = OpNode("qkv_matmul", matmul_type_)->AsIntermediate();
+ auto* qkv_matmul = OpNode("qkv_matmul", matmul2_type_)->AsIntermediate();
auto* qkv_matmul_out = VarNode("qkv_matmul_out")
- ->assert_is_op_output(matmul_type_, "Out")
+ ->assert_is_op_output(matmul2_type_, "Out")
->assert_is_op_input("transpose2", "X")
->AsIntermediate();
auto* qkv_transpose2 =
@@ -459,9 +470,14 @@ class XPUSingleEncoderFuser : public FuseBase {
*k_reshape2 >> *k_reshape2_xshape;
*k_transpose2 >> *k_transpose2_xshape;

- *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
- *qk_softmax_out >> *qkv_matmul;
- *qk_mask >> *qk_add;
+ if (with_mask_) {
+ *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
+ *qk_softmax_out >> *qkv_matmul;
+ *qk_mask >> *qk_add;
+ } else {
+ *qk_matmul >> *qk_matmul_out >> *qk_softmax >> *qk_softmax_out >>
+ *qkv_matmul;
+ }

if (norm_before_) {
*ln_before_out >> *v_mul;
@@ -513,7 +529,9 @@ class XPUSingleEncoderFuser : public FuseBase {
cpp::OpDesc op_desc;
op_desc.SetType("single_encoder");
op_desc.SetInput("Inputs", {matched.at("input")->arg()->name});
op_desc.SetInput("Mask", {matched.at("qk_mask")->arg()->name});
if (with_mask_) {
op_desc.SetInput("Mask", {matched.at("qk_mask")->arg()->name});
}
op_desc.SetInput("FCWeight",
{
matched.at("q_mul_y")->arg()->name,
@@ -645,7 +663,6 @@ class XPUSingleEncoderFuser : public FuseBase {
single_encoder_stmt->SetOp(fake_subgraph_op);

std::vector<std::string> froms = {
"qk_mask",
"k_mul_y",
"v_mul_y",
"qkv_mul_y",
@@ -660,6 +677,9 @@ class XPUSingleEncoderFuser : public FuseBase {
"qkv_ln_2_scale",
"qkv_ln_2_bias",
};
+ if (with_mask_) {
+ froms.push_back("qk_mask");
+ }
if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
froms.push_back("q_cos_embedding");
froms.push_back("q_sin_embedding");
@@ -687,10 +707,12 @@ class XPUSingleEncoderFuser : public FuseBase {
std::string input_pos_;
std::string qkv_ln_2_out_pos_;
std::string matmul_type_;
+ std::string matmul2_type_;
std::string mul_type_;
bool with_q_scale_;
bool norm_before_;
const std::string relative_emb_type_;
+ bool with_mask_;
// quant_info: mul input_max, output_max * 6 + matmul x_max:y_max, output_max
// * 2
void set_quant_info(Scope* scope,
@@ -955,7 +977,7 @@ class XPUMultiEncoderFuser {
std::string mask_name;
for (auto* encoder : all_encoders) {
auto* op_info = encoder->stmt()->op_info();
- if (mask_name.empty()) {
+ if (mask_name.empty() && op_info->HasInput("Mask")) {
mask_name = op_info->Input("Mask").front();
} else {
// CHECK(mask_name == op_info->Input("Mask").front());
@@ -1026,13 +1048,11 @@ class XPUMultiEncoderFuser {
if (all_encoders.size() == 1) {
// take care of only one encoder
in_name = op_info->Input("Inputs").front();
- mask_name = op_info->Input("Mask").front();
out_name = op_info->Output("Outputs").front();
} else if (i == 0) {
// first encoder
to_remove.insert(cur_out);
in_name = op_info->Input("Inputs").front();
- mask_name = op_info->Input("Mask").front();
} else if (i == all_encoders.size() - 1) {
// last encoder
to_remove.insert(cur_encoder);
@@ -1051,7 +1071,9 @@ class XPUMultiEncoderFuser {
for (auto kv : arg_map) {
op_desc.SetInput(kv.first, kv.second);
}
op_desc.SetInput("Mask", {mask_name});
if (!mask_name.empty()) {
op_desc.SetInput("Mask", {mask_name});
}
op_desc.SetOutput("Output", {out_name});
op_desc.SetAttr<int>("xpu", 1);
op_desc.SetAttr<int>(
@@ -1382,9 +1404,11 @@ class XPUMultiEncoderFusePass : public ProgramPass {
std::vector<std::string> input_poss{"X", "Y"};
std::vector<std::string> qkv_ln_2_out_poss{"X", "Y"};
std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
+ std::vector<std::string> matmul2_types{"matmul", "matmul_v2"};
std::vector<std::string> mul_types{"mul", "matmul", "matmul_v2"};
std::vector<bool> with_q_scales{true, false};
std::vector<bool> norm_befores{true, false};
+ std::vector<bool> with_mask{true, false};
std::vector<std::string> relative_embedding_type{
"", "__xpu__roformer_relative_embedding"};

@@ -1423,23 +1447,29 @@ class XPUMultiEncoderFusePass : public ProgramPass {
for (auto& input_pos : input_poss) {
for (auto& qkv_ln_2_out_pos : qkv_ln_2_out_poss) {
for (auto& matmul_type : matmul_types) {
- for (auto& mul_type : mul_types) {
- for (auto with_q_scale : with_q_scales) {
- for (auto norm_before : norm_befores) {
- for (auto relative_type : relative_embedding_type) {
- fusion::XPUSingleEncoderFuser single_encoder_fuser(
- act_type,
- input_pos,
- qkv_ln_2_out_pos,
- matmul_type,
- mul_type,
- with_q_scale,
- norm_before,
- relative_type);
- single_encoder_fuser(graph.get());
- fusion::XPUMultiEncoderFuser multi_encoder_fuser(
- fc_precision, adaptive_seqlen);
- multi_encoder_fuser(graph.get());
+ for (auto& matmul2_type : matmul2_types) {
+ for (auto& mul_type : mul_types) {
+ for (auto with_q_scale : with_q_scales) {
+ for (auto norm_before : norm_befores) {
+ for (auto relative_type : relative_embedding_type) {
+ for (auto mask : with_mask) {
+ fusion::XPUSingleEncoderFuser single_encoder_fuser(
+ act_type,
+ input_pos,
+ qkv_ln_2_out_pos,
+ matmul_type,
+ matmul2_type,
+ mul_type,
+ with_q_scale,
+ norm_before,
+ relative_type,
+ mask);
+ single_encoder_fuser(graph.get());
+ fusion::XPUMultiEncoderFuser multi_encoder_fuser(
+ fc_precision, adaptive_seqlen);
+ multi_encoder_fuser(graph.get());
+ }
+ }
+ }
+ }
+ }
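The enumeration above gains two axes relative to the old code: matmul2_types (the op type of the attention-output matmul) and with_mask. Each has two values, so every outer configuration now instantiates the single-encoder fuser 2 × 2 = 4 times as often; a given instantiation only rewrites the graph when its exact pattern variant is present, so the extra combinations add pass time but leave non-matching graphs untouched.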
