[XPU] support ffn intermediate size M!=4 (PaddlePaddle#9646)
newway authored and QShiX committed Nov 9, 2022
1 parent 8e0dfed commit d8dcac8
Showing 4 changed files with 21 additions and 6 deletions.
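In short, the fusion previously hard-coded the FFN intermediate size to 4x the hidden size; this commit derives the actual ratio from the first FFN matmul weight, carries it on the fused op as a new ffn_hidden_dim_scale attribute (default 4), and forwards it to XDNN as scale_of_hidden_units. A minimal standalone sketch of the derivation, assuming the [hidden_dim, intermediate_dim] weight layout that the pass reads (FfnHiddenDimScale is a hypothetical helper, not from the commit):

    // Hypothetical helper, mirroring the shape check and integer division
    // that the fuse pass performs on the first FFN matmul weight.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int FfnHiddenDimScale(const std::vector<int64_t>& ffn0_weight_shape) {
      assert(ffn0_weight_shape.size() == 2);  // [hidden_dim, intermediate_dim]
      // Integer division, as in the pass; non-integral ratios truncate.
      return static_cast<int>(ffn0_weight_shape[1] / ffn0_weight_shape[0]);
    }

    // Usage: FfnHiddenDimScale({768, 3072}) == 4 (BERT-base-sized encoder).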
20 changes: 16 additions & 4 deletions lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -340,9 +340,7 @@ class XPUSingleEncoderFuser : public FuseBase {
   };
   auto* qkv_mul_3_y =
       VarNode("qkv_mul_3_y")->assert_is_op_input(mul_type_, "Y")->AsInput();
-  auto* qkv_mul_3 = OpNode("qkv_mul_3", mul_type_)
-                        ->assert_node_satisfied(qkv_weight_teller)
-                        ->AsIntermediate();
+  auto* qkv_mul_3 = OpNode("qkv_mul_3", mul_type_)->AsIntermediate();
   auto* qkv_mul_3_out = VarNode("qkv_mul_3_out")
                             ->assert_is_op_output(mul_type_, "Out")
                             ->assert_is_op_input("elementwise_add", "X")
@@ -572,8 +570,17 @@
   auto* scope = matched.at("q_mul")->stmt()->op()->scope();
   auto q_mul_y_shape = scope->FindMutableTensor(q_mul_input_y_name)->dims();
   hidden_dim = q_mul_y_shape[0];
+  int scale_hidden_dim = 4;
+  {
+    auto* ffn0_mul_op_info = matched.at("qkv_mul_3")->stmt()->op_info();
+    auto ffn0_mul_y_name = ffn0_mul_op_info->Input("Y").front();
+    auto ffn0_mul_y_shape = scope->FindMutableTensor(ffn0_mul_y_name)->dims();
+    CHECK_EQ(ffn0_mul_y_shape.size(), 2);
+    scale_hidden_dim = ffn0_mul_y_shape[1] / ffn0_mul_y_shape[0];
+  }
   VLOG(3) << "q mul Y shape: " << q_mul_y_shape
-          << ", hidden_dim:" << hidden_dim;
+          << ", hidden_dim:" << hidden_dim
+          << ", ffn0 Y shape[1]/shape[0]:" << scale_hidden_dim;
   auto* qkv_mul_op_info = matched.at("qkv_mul")->stmt()->op_info();
   auto qkv_mul_input_y_name = qkv_mul_op_info->Input("Y").front();
   auto qkv_mul_y_shape =
@@ -625,6 +632,8 @@
   } else {
     op_desc.SetAttr<int>("relative_type", 0);
   }
+  op_desc.SetAttr<int>("ffn_hidden_dim_scale", scale_hidden_dim);
+
   auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
   auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
   sub_program_desc->AddBlock<cpp::BlockDesc>();
@@ -961,6 +970,8 @@ class XPUMultiEncoderFuser {
     per_channel = first_encoder_op_info->GetAttr<bool>("per_channel");
   }
   const int hidden_dim = first_encoder_op_info->GetAttr<int>("hidden_dim");
+  const int scale_hidden_dim =
+      first_encoder_op_info->GetAttr<int>("ffn_hidden_dim_scale");
   std::string in_name, out_name;
   std::vector<std::string> arg_names{
       "FCWeight", "FCBias", "LNScale", "LNBias"};
@@ -1073,6 +1084,7 @@
   op_desc.SetAttr<int>("hidden_dim", hidden_dim);
   op_desc.SetAttr<int>("head_num",
                        first_encoder_op_info->GetAttr<int>("head_num"));
+  op_desc.SetAttr<int>("ffn_hidden_dim_scale", scale_hidden_dim);
   op_desc.SetAttr<int>(
       "size_per_head",
       first_encoder_op_info->GetAttr<int>("size_per_head"));
5 changes: 3 additions & 2 deletions lite/kernels/xpu/__xpu__multi_encoder_compute.cc
@@ -239,7 +239,7 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) {
     qkv_attn_param.relative_pos.assign(roformer_embedding_.begin(),
                                        roformer_embedding_.end());
   }
-
+  qkv_attn_param.scale_of_hidden_units = param.ffn_hidden_dim_scale;
   if (std::is_same<TGEMM, int8_t>::value) {
     CHECK_GT(fc_input_max_.size(), 0);
   }
@@ -262,7 +262,8 @@
   std::vector<int64_t> mask_shape = param.mask->dims().Vectorize();
   std::vector<int> encoder_mask_shape =
       std::vector<int>(mask_shape.begin(), mask_shape.end());
-
+  CHECK_EQ(param.ffn_hidden_dim_scale, 4)
+      << "xpu don't support ffn_hidden_dim_scale!=4 when no vsl";
   xdnn::QKVAttnParam qkv_attn_param(batch,
                                     max_seqlen,
                                     param.head_num,
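Note that on the padded (non-vsl) path the kernel still requires the default ratio: scale_of_hidden_units is always forwarded, but the new CHECK_EQ rejects anything other than 4 when variable sequence lengths are not used. For example, an encoder with hidden_dim 768 and intermediate size 2304 gives ffn_hidden_dim_scale = 2304 / 768 = 3, which works via the vsl path but trips this check otherwise.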
1 change: 1 addition & 0 deletions lite/operators/__xpu__multi_encoder_op.cc
@@ -141,6 +141,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
   param_.n_layers = op_desc.GetAttr<int>("n_layers");
   param_.hidden_dim = op_desc.GetAttr<int>("hidden_dim");
   param_.head_num = op_desc.GetAttr<int>("head_num");
+  param_.ffn_hidden_dim_scale = op_desc.GetAttr<int>("ffn_hidden_dim_scale");
   param_.size_per_head = op_desc.GetAttr<int>("size_per_head");
   param_.act_type = op_desc.GetAttr<std::string>("act_type");
   param_.precision = op_desc.GetAttr<std::string>("precision");
1 change: 1 addition & 0 deletions lite/operators/op_params.h
@@ -1740,6 +1740,7 @@ struct XPUMultiEncoderParam : ParamBase {
   int head_num{};
   int size_per_head{};
   int hidden_dim{};
+  int ffn_hidden_dim_scale{4};
   std::string act_type{};
   int relative_type{0};
   int max_pos_len{512};  // relative embedding [max_pos_len, head_dim]
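The in-class default of 4 preserves the previously hard-coded ratio, i.e. intermediate_dim = ffn_hidden_dim_scale * hidden_dim (4 * 768 = 3072 for a BERT-base-sized encoder).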
