From dfec3fa6f1d4b934f2eaf532c037f75bb07a595e Mon Sep 17 00:00:00 2001
From: shanliang1992
Date: Wed, 20 Apr 2022 21:27:46 +0800
Subject: [PATCH] [XPU] fix lstm bug in xpu (#8886)

---
 .../xpu/__xpu__dynamic_lstm_compute.cc        | 162 ++++++++----------
 1 file changed, 67 insertions(+), 95 deletions(-)

diff --git a/lite/kernels/xpu/__xpu__dynamic_lstm_compute.cc b/lite/kernels/xpu/__xpu__dynamic_lstm_compute.cc
index 3c840a49679..ac796d66326 100644
--- a/lite/kernels/xpu/__xpu__dynamic_lstm_compute.cc
+++ b/lite/kernels/xpu/__xpu__dynamic_lstm_compute.cc
@@ -39,21 +39,21 @@ void XPUDynamicLstmCompute::PrepareForRun() {
                                  cpu_transpose_weight_0.data(),
                                  weight_0_dims[0],
                                  weight_0_dims[1]);
-  // change weight_0 from [w_ix, w_gx, w_fx, w_ox] to [w_ix, w_fx, w_gx, w_ox]
+  // change weight_0 from [w_gx, w_ix, w_fx, w_ox] to [w_ix, w_fx, w_gx, w_ox]
   transpose_weight_0_ =
       TargetWrapperXPU::MallocScratchPad(weight_0_size * sizeof(float));
   float* transpose_weight_0_addr =
       reinterpret_cast<float*>(transpose_weight_0_->addr_);
   XPU_CALL(xpu_memcpy(transpose_weight_0_addr,
-                      cpu_transpose_weight_0.data(),
-                      weight_0_size / 4 * sizeof(float),
-                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-  XPU_CALL(xpu_memcpy(transpose_weight_0_addr + weight_0_size / 2,
                       cpu_transpose_weight_0.data() + weight_0_size / 4,
                       weight_0_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(transpose_weight_0_addr + weight_0_size / 4,
                       cpu_transpose_weight_0.data() + weight_0_size / 2,
+                      weight_0_size / 4 * sizeof(float),
+                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
+  XPU_CALL(xpu_memcpy(transpose_weight_0_addr + weight_0_size / 2,
+                      cpu_transpose_weight_0.data(),
                       weight_0_size / 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(transpose_weight_0_addr + weight_0_size / 4 * 3,
@@ -70,21 +70,21 @@ void XPUDynamicLstmCompute::PrepareForRun() {
                                  cpu_transpose_weight_1.data(),
                                  weight_1_dims[0],
                                  weight_1_dims[1]);
-  // change weight_1 from [w_ih, w_gh, w_fh, w_oh] to [w_ih, w_fh, w_gh, w_oh]
+  // change weight_1 from [w_gh, w_ih, w_fh, w_oh] to [w_ih, w_fh, w_gh, w_oh]
   transpose_weight_1_ =
       TargetWrapperXPU::MallocScratchPad(weight_1_size * sizeof(float));
   float* transpose_weight_1_addr =
       reinterpret_cast<float*>(transpose_weight_1_->addr_);
   XPU_CALL(xpu_memcpy(transpose_weight_1_addr,
-                      cpu_transpose_weight_1.data(),
-                      weight_1_size / 4 * sizeof(float),
-                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-  XPU_CALL(xpu_memcpy(transpose_weight_1_addr + weight_1_size / 2,
                       cpu_transpose_weight_1.data() + weight_1_size / 4,
                       weight_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(transpose_weight_1_addr + weight_1_size / 4,
                       cpu_transpose_weight_1.data() + weight_1_size / 2,
+                      weight_1_size / 4 * sizeof(float),
+                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
+  XPU_CALL(xpu_memcpy(transpose_weight_1_addr + weight_1_size / 2,
+                      cpu_transpose_weight_1.data(),
                       weight_1_size / 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(transpose_weight_1_addr + weight_1_size / 4 * 3,
@@ -92,21 +92,21 @@ void XPUDynamicLstmCompute::PrepareForRun() {
                       weight_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
 
-  // change bias_0 from [b_ix, b_gx, b_fx, b_ox] to [b_ix, b_fx, b_gx, b_ox]
+  // change bias_0 from [b_gx, b_ix, b_fx, b_ox] to [b_ix, b_fx, b_gx, b_ox]
   const float* bias_0 = param.bias_0->template data<float>();
   int bias_0_size = param.bias_0->numel();
   bias_0_ = TargetWrapperXPU::MallocScratchPad(bias_0_size * sizeof(float));
   float* bias_0_addr = reinterpret_cast<float*>(bias_0_->addr_);
   XPU_CALL(xpu_memcpy(bias_0_addr,
-                      bias_0,
-                      bias_0_size / 4 * sizeof(float),
-                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-  XPU_CALL(xpu_memcpy(bias_0_addr + bias_0_size / 2,
                       bias_0 + bias_0_size / 4,
                       bias_0_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_0_addr + bias_0_size / 4,
                       bias_0 + bias_0_size / 2,
+                      bias_0_size / 4 * sizeof(float),
+                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
+  XPU_CALL(xpu_memcpy(bias_0_addr + bias_0_size / 2,
+                      bias_0,
                       bias_0_size / 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_0_addr + bias_0_size / 4 * 3,
@@ -114,13 +114,13 @@ void XPUDynamicLstmCompute::PrepareForRun() {
                       bias_0_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
 
-  // change bias_1 from [b_ix, b_gx, b_fx, b_ox] to [b_ix, b_fx, b_gx, b_ox]
+  // change bias_1 from [b_gx, b_ix, b_fx, b_ox] to [b_ix, b_fx, b_gx, b_ox]
   const float* bias_1 = param.bias_1->template data<float>();
   int bias_1_size = param.bias_1->numel();
   bias_1_ = TargetWrapperXPU::MallocScratchPad(bias_1_size * sizeof(float));
   float* bias_1_addr = reinterpret_cast<float*>(bias_1_->addr_);
   XPU_CALL(xpu_memcpy(bias_1_addr,
-                      bias_1,
+                      bias_1 + bias_1_size / 4,
                       bias_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_1_addr + bias_1_size / 4,
@@ -128,7 +128,7 @@ void XPUDynamicLstmCompute::PrepareForRun() {
                       bias_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_1_addr + bias_1_size / 2,
-                      bias_1 + bias_1_size / 4,
+                      bias_1,
                       bias_1_size / 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_1_addr + bias_1_size / 4 * 3,
@@ -185,62 +185,38 @@ void XPUDynamicLstmCompute::Run() {
   float* in_batch_tensor_addr =
       reinterpret_cast<float*>(in_batch_tensor->addr_);
 
-  // prepare lod and reverse lod
-  auto xpu_lod =
-      TargetWrapperXPU::MallocScratchPad(int_lod.size() * sizeof(int));
-  int* xpu_lod_addr = reinterpret_cast<int*>(xpu_lod->addr_);
-  XPU_CALL(xpu_memcpy(xpu_lod_addr,
-                      int_lod.data(),
-                      int_lod.size() * sizeof(int),
-                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-  std::vector<int> reverse_int_lod(int_lod.size());
-  auto xpu_reverse_int_lod =
-      TargetWrapperXPU::MallocScratchPad(reverse_int_lod.size() * sizeof(int));
-  int* xpu_reverse_int_lod_addr =
-      reinterpret_cast<int*>(xpu_reverse_int_lod->addr_);
-
   // reverse input if is_reverse = true
   if (is_reverse) {
     auto reverse_input = TargetWrapperXPU::MallocScratchPad(
         param.input->numel() * sizeof(float));
    float* reverse_input_addr = reinterpret_cast<float*>(reverse_input->addr_);
-    int r = xdnn::sequence_reverse(ctx.GetRawContext(),
-                                   input_addr,
-                                   xpu_lod_addr,
-                                   reverse_input_addr,
-                                   batch_size,
-                                   xdim);
+    int r = xdnn::sequence_reverse(
+        ctx.GetRawContext(),
+        input_addr,
+        reverse_input_addr,
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
+        xdim);
     CHECK_EQ(r, 0);
-    std::reverse(seq_len_tensor.begin(), seq_len_tensor.end());
-
-    // get reverse lod tensor
-    reverse_int_lod[0] = 0;
-    for (int i = 0; i < seq_len_tensor.size(); i++) {
-      reverse_int_lod[i + 1] = reverse_int_lod[i] + seq_len_tensor[i];
-    }
-    XPU_CALL(xpu_memcpy(xpu_reverse_int_lod_addr,
-                        reverse_int_lod.data(),
-                        reverse_int_lod.size() * sizeof(int),
-                        XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-
-    r = xdnn::sequence_pad(ctx.GetRawContext(),
-                           reverse_input_addr,
-                           xpu_reverse_int_lod_addr,
-                           in_batch_tensor_addr,
-                           batch_size,
-                           max_seq_len,
-                           xdim,
-                           0);
+    r = xdnn::sequence_pad(
+        ctx.GetRawContext(),
+        reverse_input_addr,
+        in_batch_tensor_addr,
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
+        batch_size,
+        max_seq_len,
+        xdim,
+        0);
     CHECK_EQ(r, 0);
   } else {
-    int r = xdnn::sequence_pad(ctx.GetRawContext(),
-                               input_addr,
-                               xpu_lod_addr,
-                               in_batch_tensor_addr,
-                               batch_size,
-                               max_seq_len,
-                               xdim,
-                               0);
+    int r = xdnn::sequence_pad(
+        ctx.GetRawContext(),
+        input_addr,
+        in_batch_tensor_addr,
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
+        batch_size,
+        max_seq_len,
+        xdim,
+        0);
     CHECK_EQ(r, 0);
   }
 
@@ -289,29 +265,28 @@ void XPUDynamicLstmCompute::Run() {
   const float* weight_0_maxptr =
       reinterpret_cast<const float*>(weight_0_max_->addr_);
   const float* weight_1_maxptr =
      reinterpret_cast<const float*>(weight_1_max_->addr_);
-  r = xdnn::lstm_train_for_old_paddle(
-      ctx.GetRawContext(),
-      transpose_in_addr,
-      h0,
-      c0,
-      transpose_weight_0_addr,
-      transpose_weight_1_addr,
-      bias_0_addr,
-      bias_1_addr,
-      transpose_out_addr,
-      last_h_addr,
-      last_c_addr,
-      batch_size,
-      xdim,
-      hdim,
-      max_seq_len,
-      seq_len_tensor,
-      nullptr,
-      nullptr,
-      weight_0_maxptr,
-      weight_1_maxptr,
-      i_f_g_o_addr,
-      c_addr);
+  r = xdnn::lstm_train(ctx.GetRawContext(),
+                       transpose_in_addr,
+                       h0,
+                       c0,
+                       transpose_weight_0_addr,
+                       transpose_weight_1_addr,
+                       bias_0_addr,
+                       bias_1_addr,
+                       transpose_out_addr,
+                       last_h_addr,
+                       last_c_addr,
+                       batch_size,
+                       xdim,
+                       hdim,
+                       max_seq_len,
+                       seq_len_tensor,
+                       nullptr,
+                       nullptr,
+                       weight_0_maxptr,
+                       weight_1_maxptr,
+                       i_f_g_o_addr,
+                       c_addr);
   CHECK_EQ(r, 0);
   // transpose from transpose_out[seq_len, batch_size, hdim] to
@@ -337,9 +312,7 @@ void XPUDynamicLstmCompute::Run() {
         ctx.GetRawContext(),
         out_batch_tensor_addr,
         reverse_output_addr,
-        {reverse_int_lod.data(),
-         static_cast<int>(reverse_int_lod.size()),
-         xpu_reverse_int_lod_addr},
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
         max_seq_len,
         hdim);
     CHECK_EQ(r, 0);
@@ -347,9 +320,8 @@ void XPUDynamicLstmCompute::Run() {
     r = xdnn::sequence_reverse(
         ctx.GetRawContext(),
         reverse_output_addr,
-        xpu_reverse_int_lod_addr,
         param.hidden->template mutable_data<float>(TARGET(kXPU)),
-        batch_size,
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
         hdim);
     CHECK_EQ(r, 0);
   } else {
@@ -357,7 +329,7 @@ void XPUDynamicLstmCompute::Run() {
         ctx.GetRawContext(),
         out_batch_tensor_addr,
         param.hidden->template mutable_data<float>(TARGET(kXPU)),
-        {int_lod.data(), static_cast<int>(int_lod.size()), xpu_lod_addr},
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
         max_seq_len,
         hdim);
     CHECK_EQ(r, 0);
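
Note on the fix: the PrepareForRun hunks all implement the same correction.
Per the updated comments, the pre-transposed weight and bias buffers arrive
from the checkpoint in gate order [g, i, f, o], not [i, g, f, o] as the
stale comments claimed, and must be rearranged into the [i, f, g, o] layout
the kernel consumes; the xpu_memcpy source/destination offsets are reordered
accordingly. The lod handling in Run() is also simplified: instead of
staging the lod (and a hand-built reversed lod) in device scratch pads, the
new calls pass a {host_ptr, size, nullptr} aggregate, presumably leaving the
device-side copy to xdnn, and lstm_train_for_old_paddle is replaced by
lstm_train.

A minimal host-side sketch of the gate reordering, with hypothetical names
(reorder_gates, src, dst, n are not from the patch) and plain std::memcpy
standing in for xpu_memcpy:

#include <cstring>

// Reorder four equally sized gate chunks from [g, i, f, o] (src) to
// [i, f, g, o] (dst). n is the total element count, assumed divisible
// by 4; q is the element count of one gate chunk.
void reorder_gates(const float* src, float* dst, int n) {
  const int q = n / 4;
  std::memcpy(dst + 0 * q, src + 1 * q, q * sizeof(float));  // i <- chunk 1
  std::memcpy(dst + 1 * q, src + 2 * q, q * sizeof(float));  // f <- chunk 2
  std::memcpy(dst + 2 * q, src + 0 * q, q * sizeof(float));  // g <- chunk 0
  std::memcpy(dst + 3 * q, src + 3 * q, q * sizeof(float));  // o <- chunk 3
}

The patch itself copies half the buffer in the third xpu_memcpy of each
block and lets the fourth copy overwrite the spilled tail; the net effect is
the per-quarter mapping shown above.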