Skip to content

Commit

Permalink
Add support for bias being None in the fused_attention op. (#37411)
Browse files Browse the repository at this point in the history
Add support for bias being None in the fused_attention op.
  • Loading branch information
limin2021 authored Nov 23, 2021
1 parent 4812eda commit 1a8786c
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 71 deletions.
77 changes: 47 additions & 30 deletions paddle/fluid/operators/fused/fused_attention_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionOp");

if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
Expand All @@ -54,8 +50,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut",
"FusedAttentionOp");
if (ctx->HasInput("QKVBias")) {
OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut",
"FusedAttentionOp");
}
OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut",
Expand Down Expand Up @@ -107,6 +105,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]",
x_dim, y_dim));

PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));

if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
Expand All @@ -119,8 +124,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});

if (ctx->HasInput("QKVBias")) {
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
}
// [3, batch_size, num_head, seq_len, head_size]
ctx->SetOutputDim("TransposeOut2",
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
Expand Down Expand Up @@ -173,11 +181,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddInput("QKVW", "The qkv weight tensor.");
AddInput("QKVBias", "The qkv bias tensor.");
AddInput("QKVBias", "The qkv bias tensor.").AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.");
AddInput("OutLinearBias", "The out_linear bias tensor.");
AddInput("OutLinearBias", "The out_linear bias tensor.").AsDispensable();
AddInput("Ln2Scale",
"(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
Expand Down Expand Up @@ -379,12 +387,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad");

if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
Expand All @@ -399,14 +403,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}

ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias"));
if (ctx->HasOutput(framework::GradVarName("OutLinearBias"))) {
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
ctx->GetInputDim("OutLinearW"));
ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));
if (ctx->HasOutput(framework::GradVarName("QKVBias"))) {
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));
}

if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim(framework::GradVarName("LnOut"),
Expand Down Expand Up @@ -434,8 +441,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut"));
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut"));
if (ctx->HasOutput(framework::GradVarName("QKVBiasOut"))) {
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
ctx->GetInputDim("OutLinearOut"));
}
Expand All @@ -462,7 +471,15 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
// inputs x, parameters and their grad.
op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW"));
op->SetInput("QKVBias", this->Input("QKVBias"));

if (this->HasInput("QKVBias")) {
op->SetInput("QKVBias", this->Input("QKVBias"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
}

if (this->HasInput("SrcMask")) {
op->SetInput("SrcMask", this->Input("SrcMask"));
Expand All @@ -472,7 +489,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
}

op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
if (this->HasInput("OutLinearBias")) {
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));
}

op->SetAttrMap(this->Attrs());
bool is_pre_layer_norm =
Expand Down Expand Up @@ -503,10 +524,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {

op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));

op->SetOutput(framework::GradVarName("OutLinearW"),
this->InputGrad("OutLinearW"));

Expand All @@ -528,7 +546,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->Output("BiasDropoutResidualOut"));
}
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));

op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
op->SetInput("QKOut", this->Output("QKOut"));
op->SetInput("QKTVOut", this->Output("QKTVOut"));
Expand All @@ -553,8 +571,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
}

op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));

op->SetOutput(framework::GradVarName("QKTVOut"),
this->OutputGrad("QKTVOut"));
op->SetOutput(framework::GradVarName("TransposeOut2"),
Expand Down
107 changes: 76 additions & 31 deletions paddle/fluid/operators/fused/fused_attention_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

auto *x_data = input_x->data<T>();
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_bias_out_data =
(qkv_bias == nullptr) ? nullptr
: qkv_bias_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for FMHA.
auto *transpose_out_2_data =
Expand All @@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

// get data ptr for out_linear.
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>();
auto *out_linear_bias_data =
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());

// get data ptr for bias+dropout+residual+layernorm
Expand All @@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);

bool compute_bias = true;
if (qkv_bias == nullptr) {
compute_bias = false;
}
// (transA, transB) = (false, true); compute_bias is true only when qkv_bias is provided.
auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, true,
bsz_seq, output_size, input_size, true);
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, true, bsz_seq,
output_size, input_size, compute_bias);

AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_rate,
Expand Down Expand Up @@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
qkv_bias_out);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
if (qkv_bias == nullptr) {
fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
} else {
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
}

// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
Expand Down Expand Up @@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>();
auto *out_linear_bias_data =
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();

// fw output
auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
Expand Down Expand Up @@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
// When qkv_bias is not nullptr, d_qkv_out is equal to d_qkv_bias_out, so the
// space can be reused.
auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
? nullptr
: d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data =
(d_qkv_bias_out == nullptr)
? nullptr
: d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
auto *d_transpose_out_2_data =
d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
Expand All @@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));

auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = (d_qkv_bias == nullptr)
? nullptr
: d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_weight_data =
d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_bias_data =
d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
(d_out_linear_bias == nullptr)
? nullptr
: d_out_linear_bias->mutable_data<T>(ctx.GetPlace());

const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
Expand All @@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {

bool transA = false;
bool transB = true;
bool compute_bias = true;
bool compute_qkv_bias = true;
if (qkv_bias == nullptr) {
compute_qkv_bias = false;
}
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
output_size, input_size, compute_qkv_bias);
AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_prob,
is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
Expand All @@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
output_size = hidden_size;
transA = false;
transB = false;
compute_bias = false;
bool compute_bias = false;
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
Expand Down Expand Up @@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out, d_fmha_out,
d_out_linear_weight, nullptr);

fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data,
bsz_seq * 3 * num_head * dim_head * sizeof(T),
cudaMemcpyDeviceToDevice);
if (qkv_bias != nullptr) {
fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
} else {
fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_out);
}

if (pre_layer_norm) {
auto *ln_mean = ctx.Input<Tensor>("LnMean");
Expand All @@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));

qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
d_qkv_weight, d_qkv_bias);
if (qkv_bias != nullptr) {
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out,
d_ln_out, d_qkv_weight, d_qkv_bias);
} else {
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out,
d_qkv_weight, d_qkv_bias);
}
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data);
} else {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
d_qkv_weight, d_qkv_bias);
if (qkv_bias != nullptr) {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
d_qkv_weight, d_qkv_bias);
} else {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x,
d_qkv_weight, d_qkv_bias);
}
}
// gradient accumulation
std::vector<const Tensor *> ins;
Expand Down
Loading

0 comments on commit 1a8786c

Please sign in to comment.