From efa834875e137178372021d4e529274a99e1f9ae Mon Sep 17 00:00:00 2001
From: linwei
Date: Thu, 17 Nov 2022 17:58:04 +0800
Subject: [PATCH] [xpu] multi_encoder supports no mask input, such as ViT

---
 .../fusion/__xpu__multi_encoder_fuse_pass.cc  | 110 +++++++++++-------
 ...xpu__multi_encoder_slice_link_fuse_pass.cc |   3 +-
 .../xpu/__xpu__multi_encoder_compute.cc       |  45 +++++++
 3 files changed, 117 insertions(+), 41 deletions(-)

diff --git a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
index 13f972a4530..ad0d526ea80 100644
--- a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
+++ b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -48,18 +48,22 @@ class XPUSingleEncoderFuser : public FuseBase {
                                  const std::string& input_pos = "Y",
                                  const std::string& qkv_ln_2_out_pos = "Y",
                                  const std::string& matmul_type = "matmul",
+                                 const std::string& matmul2_type = "matmul_v2",
                                  const std::string& mul_type = "mul",
                                  bool with_q_scale = true,
                                  bool norm_before = false,
-                                 const std::string& relative_type = "")
+                                 const std::string& relative_type = "",
+                                 bool with_mask = true)
       : act_type_(act_type),
         input_pos_(input_pos),
         qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
         matmul_type_(matmul_type),
+        matmul2_type_(matmul2_type),
         mul_type_(mul_type),
         with_q_scale_(with_q_scale),
         norm_before_(norm_before),
-        relative_emb_type_(relative_type) {}
+        relative_emb_type_(relative_type),
+        with_mask_(with_mask) {}
 
   void BuildPattern() override {
     auto* input = VarNode("input")
@@ -213,18 +217,25 @@
                          ->AsIntermediate();
 
     auto* qk_matmul = OpNode("qk_matmul", matmul_type_)->AsIntermediate();
+    std::string op_after_qk_matmul = with_mask_ ? "elementwise_add" : "softmax";
"elementwise_add" : "softmax"; auto* qk_matmul_out = VarNode("qk_matmul_out") ->assert_is_op_output(matmul_type_, "Out") - ->assert_is_op_input("elementwise_add", "X") + ->assert_is_op_input(op_after_qk_matmul, "X") ->AsIntermediate(); - auto* qk_mask = VarNode("qk_mask") - ->assert_is_op_input("elementwise_add", "Y") - ->AsInput(); - auto* qk_add = OpNode("qk_add", "elementwise_add")->AsIntermediate(); - auto* qk_add_out = VarNode("qk_add_out") - ->assert_is_op_output("elementwise_add", "Out") - ->assert_is_op_input("softmax", "X") - ->AsIntermediate(); + PMNode* qk_mask = nullptr; + PMNode* qk_add = nullptr; + PMNode* qk_add_out = nullptr; + if (with_mask_) { + qk_mask = VarNode("qk_mask") + ->assert_is_op_input("elementwise_add", "Y") + ->AsInput(); + qk_add = OpNode("qk_add", "elementwise_add")->AsIntermediate(); + qk_add_out = VarNode("qk_add_out") + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("softmax", "X") + ->AsIntermediate(); + } + auto* qk_softmax = OpNode("qk_softmax", "softmax")->AsIntermediate(); auto* qk_softmax_out = VarNode("qk_softmax_out") ->assert_is_op_output("softmax", "Out") @@ -256,16 +267,16 @@ class XPUSingleEncoderFuser : public FuseBase { auto* v_transpose2 = OpNode("v_transpose2", "transpose2")->AsIntermediate(); auto* v_transpose2_out = VarNode("v_transpose2_out") ->assert_is_op_output("transpose2", "Out") - ->assert_is_op_input(matmul_type_, "Y") + ->assert_is_op_input(matmul2_type_, "Y") ->AsIntermediate(); auto* v_transpose2_xshape = VarNode("v_transpose2_xshape") ->assert_is_op_output("transpose2", "XShape") ->AsIntermediate(); - auto* qkv_matmul = OpNode("qkv_matmul", matmul_type_)->AsIntermediate(); + auto* qkv_matmul = OpNode("qkv_matmul", matmul2_type_)->AsIntermediate(); auto* qkv_matmul_out = VarNode("qkv_matmul_out") - ->assert_is_op_output(matmul_type_, "Out") + ->assert_is_op_output(matmul2_type_, "Out") ->assert_is_op_input("transpose2", "X") ->AsIntermediate(); auto* qkv_transpose2 = @@ -459,9 +470,14 @@ class XPUSingleEncoderFuser : public FuseBase { *k_reshape2 >> *k_reshape2_xshape; *k_transpose2 >> *k_transpose2_xshape; - *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >> - *qk_softmax_out >> *qkv_matmul; - *qk_mask >> *qk_add; + if (with_mask_) { + *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >> + *qk_softmax_out >> *qkv_matmul; + *qk_mask >> *qk_add; + } else { + *qk_matmul >> *qk_matmul_out >> *qk_softmax >> *qk_softmax_out >> + *qkv_matmul; + } if (norm_before_) { *ln_before_out >> *v_mul; @@ -513,7 +529,9 @@ class XPUSingleEncoderFuser : public FuseBase { cpp::OpDesc op_desc; op_desc.SetType("single_encoder"); op_desc.SetInput("Inputs", {matched.at("input")->arg()->name}); - op_desc.SetInput("Mask", {matched.at("qk_mask")->arg()->name}); + if (with_mask_) { + op_desc.SetInput("Mask", {matched.at("qk_mask")->arg()->name}); + } op_desc.SetInput("FCWeight", { matched.at("q_mul_y")->arg()->name, @@ -645,7 +663,6 @@ class XPUSingleEncoderFuser : public FuseBase { single_encoder_stmt->SetOp(fake_subgraph_op); std::vector froms = { - "qk_mask", "k_mul_y", "v_mul_y", "qkv_mul_y", @@ -660,6 +677,9 @@ class XPUSingleEncoderFuser : public FuseBase { "qkv_ln_2_scale", "qkv_ln_2_bias", }; + if (with_mask_) { + froms.push_back("qk_mask"); + } if (relative_emb_type_ == "__xpu__roformer_relative_embedding") { froms.push_back("q_cos_embedding"); froms.push_back("q_sin_embedding"); @@ -687,10 +707,12 @@ class XPUSingleEncoderFuser : public FuseBase { std::string input_pos_; 
   std::string qkv_ln_2_out_pos_;
   std::string matmul_type_;
+  std::string matmul2_type_;
   std::string mul_type_;
   bool with_q_scale_;
   bool norm_before_;
   const std::string relative_emb_type_;
+  bool with_mask_;
   // quant_info: mul input_max, output_max * 6 + matmul x_max:y_max, output_max
   // * 2
   void set_quant_info(Scope* scope,
@@ -955,7 +977,7 @@
     std::string mask_name;
     for (auto* encoder : all_encoders) {
       auto* op_info = encoder->stmt()->op_info();
-      if (mask_name.empty()) {
+      if (mask_name.empty() && op_info->HasInput("Mask")) {
         mask_name = op_info->Input("Mask").front();
       } else {
         // CHECK(mask_name == op_info->Input("Mask").front());
@@ -1026,13 +1048,11 @@
       if (all_encoders.size() == 1) {
         // take care of only one encoder
         in_name = op_info->Input("Inputs").front();
-        mask_name = op_info->Input("Mask").front();
         out_name = op_info->Output("Outputs").front();
       } else if (i == 0) {
         // first encoder
         to_remove.insert(cur_out);
         in_name = op_info->Input("Inputs").front();
-        mask_name = op_info->Input("Mask").front();
       } else if (i == all_encoders.size() - 1) {
         // last encoder
         to_remove.insert(cur_encoder);
@@ -1051,7 +1071,9 @@
     for (auto kv : arg_map) {
       op_desc.SetInput(kv.first, kv.second);
     }
-    op_desc.SetInput("Mask", {mask_name});
+    if (!mask_name.empty()) {
+      op_desc.SetInput("Mask", {mask_name});
+    }
     op_desc.SetOutput("Output", {out_name});
     op_desc.SetAttr("xpu", 1);
     op_desc.SetAttr(
@@ -1382,9 +1404,11 @@
     std::vector<std::string> input_poss{"X", "Y"};
     std::vector<std::string> qkv_ln_2_out_poss{"X", "Y"};
     std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
+    std::vector<std::string> matmul2_types{"matmul", "matmul_v2"};
     std::vector<std::string> mul_types{"mul", "matmul", "matmul_v2"};
     std::vector<bool> with_q_scales{true, false};
     std::vector<bool> norm_befores{true, false};
+    std::vector<bool> with_mask{true, false};
     std::vector<std::string> relative_embedding_type{
         "", "__xpu__roformer_relative_embedding"};
 
@@ -1423,23 +1447,29 @@
     for (auto& input_pos : input_poss) {
       for (auto& qkv_ln_2_out_pos : qkv_ln_2_out_poss) {
         for (auto& matmul_type : matmul_types) {
-          for (auto& mul_type : mul_types) {
-            for (auto with_q_scale : with_q_scales) {
-              for (auto norm_before : norm_befores) {
-                for (auto relative_type : relative_embedding_type) {
-                  fusion::XPUSingleEncoderFuser single_encoder_fuser(
-                      act_type,
-                      input_pos,
-                      qkv_ln_2_out_pos,
-                      matmul_type,
-                      mul_type,
-                      with_q_scale,
-                      norm_before,
-                      relative_type);
-                  single_encoder_fuser(graph.get());
-                  fusion::XPUMultiEncoderFuser multi_encoder_fuser(
-                      fc_precision, adaptive_seqlen);
-                  multi_encoder_fuser(graph.get());
+          for (auto& matmul2_type : matmul2_types) {
+            for (auto& mul_type : mul_types) {
+              for (auto with_q_scale : with_q_scales) {
+                for (auto norm_before : norm_befores) {
+                  for (auto relative_type : relative_embedding_type) {
+                    for (auto mask : with_mask) {
+                      fusion::XPUSingleEncoderFuser single_encoder_fuser(
+                          act_type,
+                          input_pos,
+                          qkv_ln_2_out_pos,
+                          matmul_type,
+                          matmul2_type,
+                          mul_type,
+                          with_q_scale,
+                          norm_before,
+                          relative_type,
+                          mask);
+                      single_encoder_fuser(graph.get());
+                      fusion::XPUMultiEncoderFuser multi_encoder_fuser(
+                          fc_precision, adaptive_seqlen);
+                      multi_encoder_fuser(graph.get());
+                    }
+                  }
                 }
               }
             }
diff --git a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc
index 2d009df752e..7037805c28f 100644
--- a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc
+++ b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc
@@ -57,7 +57,8 @@ class XPUMultiEncoderSliceLinkFuser : public FuseBase {
       layer_norm = OpNode("layer_norm", "layer_norm");
       layer_norm_out = VarNode("layer_norm_out")
                            ->assert_is_op_output("layer_norm", "Y")
-                           ->assert_is_op_input("slice", "Input");
+                           ->assert_is_op_input("slice", "Input")
+                           ->assert_only_one_output();
     } else {
       xpu_encoder->assert_op_attr<bool>("norm_before", false);
       encoder_out->assert_is_op_input("slice", "Input")->AsIntermediate();
diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
index 03ffad6e592..33c85eea757 100644
--- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
+++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
@@ -255,6 +255,51 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) {
         arg_ln_bias_,
         qkv_attn_param);
     CHECK_EQ(r, 0);
+  } else if (param.mask == nullptr) {
+    // When there is no mask input (e.g. ViT), create a LoD to act as the
+    // VSL (variable sequence length) information.
+    int batch = static_cast<int>(param.input->dims()[0]);
+    int max_seqlen = static_cast<int>(param.input->dims()[1]);
+    std::vector<int> lod;
+    for (int i = 0; i < batch + 1; i++) {
+      lod.push_back(i * max_seqlen);
+    }
+    query_lod = {lod.data(), static_cast<int>(lod.size()), nullptr};
+    // No padding is needed, with or without a trailing slice.
+    int max_pad_seqlen = -1;
+    xdnn::QKVAttnParam qkv_attn_param(query_lod, /* lod */
+                                      param.head_num,
+                                      param.size_per_head,
+                                      qkv_act,
+                                      slice_idx,
+                                      true /* qkv fusion */,
+                                      max_pad_seqlen,
+                                      param.hidden_dim,
+                                      param.norm_before, /* is_pre_norm */
+                                      param.per_channel);
+    qkv_attn_param.quant_type_.assign(quant_types_.begin(), quant_types_.end());
+    if (relative_type_ == 1) {
+      qkv_attn_param.relative_type = relative_type_;
+      qkv_attn_param.max_pos_len = param.max_pos_len;
+      qkv_attn_param.relative_pos.assign(roformer_embedding_.begin(),
+                                         roformer_embedding_.end());
+    }
+    qkv_attn_param.scale_of_hidden_units = param.ffn_hidden_dim_scale;
+    if (std::is_same<TGEMM, int8_t>::value) {
+      CHECK_GT(fc_input_max_.size(), 0);
+    }
+    int r = xdnn::transformer_encoder<T, TW, TGEMM>(
+        ctx.GetRawContext(),
+        in,
+        *(XPUMultiEncoderCompute::get_weight<TW>()),
+        out,
+        fc_input_max_,
+        fc_weight_max_,
+        arg_fc_bias_,
+        arg_ln_scale_,
+        arg_ln_bias_,
+        qkv_attn_param);
+    CHECK_EQ(r, 0);
   } else {
     // no vsl
     int batch = static_cast<int>(param.input->dims()[0]);
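
Note for reviewers: the heart of the kernel-side change is the LoD synthesized when param.mask == nullptr. It stands in for the variable-sequence-length (VSL) bookkeeping that the mask path normally provides, by declaring every sequence in the batch to be exactly max_seqlen tokens long. A minimal standalone sketch of that construction (plain C++; the function and variable names below are illustrative, not from the patch):

    #include <cstdio>
    #include <vector>

    // Cumulative token offsets: lod[i] marks where sequence i starts in the
    // flattened token stream, so lod has batch + 1 entries and the i-th
    // sequence spans [lod[i], lod[i + 1]).
    std::vector<int> BuildFullLengthLod(int batch, int max_seqlen) {
      std::vector<int> lod;
      lod.reserve(batch + 1);
      for (int i = 0; i < batch + 1; i++) {
        lod.push_back(i * max_seqlen);
      }
      return lod;
    }

    int main() {
      // A ViT-style input: 3 images, 197 patch tokens each.
      std::vector<int> lod = BuildFullLengthLod(3, 197);
      for (int v : lod) printf("%d ", v);  // prints: 0 197 394 591
      printf("\n");
      return 0;
    }

Since no sequence is shorter than max_seqlen, no token is masked out and nothing has to be padded, which is why the new branch can pass max_pad_seqlen = -1 to xdnn::QKVAttnParam.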