[XPU] fix lstm bug in xpu (#8886)
shanliang1992 authored Apr 20, 2022
1 parent caaf1cf commit dfec3fa
Showing 1 changed file with 67 additions and 95 deletions.
162 changes: 67 additions & 95 deletions lite/kernels/xpu/__xpu__dynamic_lstm_compute.cc
@@ -39,21 +39,21 @@ void XPUDynamicLstmCompute::PrepareForRun() {
                                      weight_0_dims[0],
                                      weight_0_dims[1]);

-  // change weight_0 from [w_ix, w_gx, w_fx, w_ox] to [w_ix, w_fx, w_gx, w_ox]
+  // change weight_0 from [w_gx, w_ix, w_fx, w_ox] to [w_ix, w_fx, w_gx, w_ox]
   transpose_weight_0_ =
       TargetWrapperXPU::MallocScratchPad(weight_0_size * sizeof(float));
   float* transpose_weight_0_addr =
       reinterpret_cast<float*>(transpose_weight_0_->addr_);
   XPU_CALL(xpu_memcpy(transpose_weight_0_addr,
-                      cpu_transpose_weight_0.data(),
-                      weight_0_size / 4 * sizeof(float),
-                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-  XPU_CALL(xpu_memcpy(transpose_weight_0_addr + weight_0_size / 2,
                       cpu_transpose_weight_0.data() + weight_0_size / 4,
                       weight_0_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(transpose_weight_0_addr + weight_0_size / 4,
                       cpu_transpose_weight_0.data() + weight_0_size / 2,
                       weight_0_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
+  XPU_CALL(xpu_memcpy(transpose_weight_0_addr + weight_0_size / 2,
+                      cpu_transpose_weight_0.data(),
+                      weight_0_size / 2 * sizeof(float),
+                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(transpose_weight_0_addr + weight_0_size / 4 * 3,
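Reading the reorder: per the updated comments, the fused gate blocks arrive from Paddle as [g, i, f, o] (the old code assumed [i, g, f, o]), while the XPU kernel consumes them as [i, f, g, o]. A minimal host-side sketch of the new copy pattern — the buffer names and the helper are illustrative, not part of the file:

    #include <cstring>
    #include <vector>

    // Illustrative only: reorder a fused gate buffer laid out as [g, i, f, o]
    // (one block of `block` floats per gate) into the [i, f, g, o] layout the
    // XPU kernel expects, mirroring the xpu_memcpy pattern above.
    std::vector<float> ReorderGates(const std::vector<float>& src, size_t block) {
      std::vector<float> dst(4 * block);
      std::memcpy(dst.data(), src.data() + block, block * sizeof(float));              // i
      std::memcpy(dst.data() + block, src.data() + 2 * block, block * sizeof(float));  // f
      std::memcpy(dst.data() + 2 * block, src.data(), 2 * block * sizeof(float));      // g (plus a stray i)
      std::memcpy(dst.data() + 3 * block, src.data() + 3 * block, block * sizeof(float));  // o overwrites the stray i
      return dst;
    }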
@@ -70,65 +70,65 @@ void XPUDynamicLstmCompute::PrepareForRun() {
                                      cpu_transpose_weight_1.data(),
                                      weight_1_dims[0],
                                      weight_1_dims[1]);
-  // change weight_1 from [w_ih, w_gh, w_fh, w_oh] to [w_ih, w_fh, w_gh, w_oh]
+  // change weight_1 from [w_gh, w_ih, w_fh, w_oh] to [w_ih, w_fh, w_gh, w_oh]
   transpose_weight_1_ =
       TargetWrapperXPU::MallocScratchPad(weight_1_size * sizeof(float));
   float* transpose_weight_1_addr =
       reinterpret_cast<float*>(transpose_weight_1_->addr_);
   XPU_CALL(xpu_memcpy(transpose_weight_1_addr,
-                      cpu_transpose_weight_1.data(),
-                      weight_1_size / 4 * sizeof(float),
-                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-  XPU_CALL(xpu_memcpy(transpose_weight_1_addr + weight_1_size / 2,
                       cpu_transpose_weight_1.data() + weight_1_size / 4,
                       weight_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(transpose_weight_1_addr + weight_1_size / 4,
                       cpu_transpose_weight_1.data() + weight_1_size / 2,
                       weight_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
+  XPU_CALL(xpu_memcpy(transpose_weight_1_addr + weight_1_size / 2,
+                      cpu_transpose_weight_1.data(),
+                      weight_1_size / 2 * sizeof(float),
+                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(transpose_weight_1_addr + weight_1_size / 4 * 3,
                       cpu_transpose_weight_1.data() + weight_1_size / 4 * 3,
                       weight_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));

-  // change bias_0 from [b_ix, b_gx, b_fx, b_ox] to [b_ix, b_fx, b_gx, b_ox]
+  // change bias_0 from [b_gx, b_ix, b_fx, b_ox] to [b_ix, b_fx, b_gx, b_ox]
   const float* bias_0 = param.bias_0->template data<float>();
   int bias_0_size = param.bias_0->numel();
   bias_0_ = TargetWrapperXPU::MallocScratchPad(bias_0_size * sizeof(float));
   float* bias_0_addr = reinterpret_cast<float*>(bias_0_->addr_);
   XPU_CALL(xpu_memcpy(bias_0_addr,
-                      bias_0,
-                      bias_0_size / 4 * sizeof(float),
-                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-  XPU_CALL(xpu_memcpy(bias_0_addr + bias_0_size / 2,
                       bias_0 + bias_0_size / 4,
                       bias_0_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_0_addr + bias_0_size / 4,
                       bias_0 + bias_0_size / 2,
                       bias_0_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
+  XPU_CALL(xpu_memcpy(bias_0_addr + bias_0_size / 2,
+                      bias_0,
+                      bias_0_size / 2 * sizeof(float),
+                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_0_addr + bias_0_size / 4 * 3,
                       bias_0 + bias_0_size / 4 * 3,
                       bias_0_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));

-  // change bias_1 from [b_ix, b_gx, b_fx, b_ox] to [b_ix, b_fx, b_gx, b_ox]
+  // change bias_1 from [b_gx, b_ix, b_fx, b_ox] to [b_ix, b_fx, b_gx, b_ox]
   const float* bias_1 = param.bias_1->template data<float>();
   int bias_1_size = param.bias_1->numel();
   bias_1_ = TargetWrapperXPU::MallocScratchPad(bias_1_size * sizeof(float));
   float* bias_1_addr = reinterpret_cast<float*>(bias_1_->addr_);
   XPU_CALL(xpu_memcpy(bias_1_addr,
-                      bias_1,
+                      bias_1 + bias_1_size / 4,
                       bias_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_1_addr + bias_1_size / 4,
                       bias_1 + bias_1_size / 2,
                       bias_1_size / 4 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_1_addr + bias_1_size / 2,
-                      bias_1 + bias_1_size / 4,
+                      bias_1,
                       bias_1_size / 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   XPU_CALL(xpu_memcpy(bias_1_addr + bias_1_size / 4 * 3,
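All four buffers (weight_0, weight_1, bias_0, bias_1) now follow the same mapping [g, i, f, o] -> [i, f, g, o]. Note that the third copy in each block moves half the buffer at once: it lands [g, i] in the upper two slots, and the final quarter-sized copy then overwrites the stray i with o. A tiny self-contained check of that overlap trick (one float per gate stands in for a whole block; purely illustrative):

    #include <cassert>
    #include <cstring>

    int main() {
      // Source order [g, i, f, o]; values chosen so the target order reads 0,1,2,3.
      float src[4] = {/*g=*/2, /*i=*/0, /*f=*/1, /*o=*/3};
      float dst[4];
      std::memcpy(dst, src + 1, 1 * sizeof(float));      // i -> slot 0
      std::memcpy(dst + 1, src + 2, 1 * sizeof(float));  // f -> slot 1
      std::memcpy(dst + 2, src, 2 * sizeof(float));      // [g, i] -> slots 2..3
      std::memcpy(dst + 3, src + 3, 1 * sizeof(float));  // o overwrites the stray i
      for (int k = 0; k < 4; ++k) assert(dst[k] == k);   // final order [i, f, g, o]
      return 0;
    }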
@@ -185,62 +185,38 @@ void XPUDynamicLstmCompute::Run() {
   float* in_batch_tensor_addr =
       reinterpret_cast<float*>(in_batch_tensor->addr_);

-  // prepare lod and reverse lod
-  auto xpu_lod =
-      TargetWrapperXPU::MallocScratchPad(int_lod.size() * sizeof(int));
-  int* xpu_lod_addr = reinterpret_cast<int*>(xpu_lod->addr_);
-  XPU_CALL(xpu_memcpy(xpu_lod_addr,
-                      int_lod.data(),
-                      int_lod.size() * sizeof(int),
-                      XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-  std::vector<int> reverse_int_lod(int_lod.size());
-  auto xpu_reverse_int_lod =
-      TargetWrapperXPU::MallocScratchPad(reverse_int_lod.size() * sizeof(int));
-  int* xpu_reverse_int_lod_addr =
-      reinterpret_cast<int*>(xpu_reverse_int_lod->addr_);
-
   // reverse input if is_reverse = true
   if (is_reverse) {
     auto reverse_input = TargetWrapperXPU::MallocScratchPad(
         param.input->numel() * sizeof(float));
     float* reverse_input_addr = reinterpret_cast<float*>(reverse_input->addr_);
-    int r = xdnn::sequence_reverse<float, int>(ctx.GetRawContext(),
-                                               input_addr,
-                                               xpu_lod_addr,
-                                               reverse_input_addr,
-                                               batch_size,
-                                               xdim);
+    int r = xdnn::sequence_reverse<float, int>(
+        ctx.GetRawContext(),
+        input_addr,
+        reverse_input_addr,
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
+        xdim);
     CHECK_EQ(r, 0);
-    std::reverse(seq_len_tensor.begin(), seq_len_tensor.end());
-
-    // get reverse lod tensor
-    reverse_int_lod[0] = 0;
-    for (int i = 0; i < seq_len_tensor.size(); i++) {
-      reverse_int_lod[i + 1] = reverse_int_lod[i] + seq_len_tensor[i];
-    }
-    XPU_CALL(xpu_memcpy(xpu_reverse_int_lod_addr,
-                        reverse_int_lod.data(),
-                        reverse_int_lod.size() * sizeof(int),
-                        XPUMemcpyKind::XPU_HOST_TO_DEVICE));
-
-    r = xdnn::sequence_pad<float, int>(ctx.GetRawContext(),
-                                       reverse_input_addr,
-                                       xpu_reverse_int_lod_addr,
-                                       in_batch_tensor_addr,
-                                       batch_size,
-                                       max_seq_len,
-                                       xdim,
-                                       0);
+    r = xdnn::sequence_pad<float, int>(
+        ctx.GetRawContext(),
+        reverse_input_addr,
+        in_batch_tensor_addr,
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
+        batch_size,
+        max_seq_len,
+        xdim,
+        0);
     CHECK_EQ(r, 0);
   } else {
-    int r = xdnn::sequence_pad<float, int>(ctx.GetRawContext(),
-                                           input_addr,
-                                           xpu_lod_addr,
-                                           in_batch_tensor_addr,
-                                           batch_size,
-                                           max_seq_len,
-                                           xdim,
-                                           0);
+    int r = xdnn::sequence_pad<float, int>(
+        ctx.GetRawContext(),
+        input_addr,
+        in_batch_tensor_addr,
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
+        batch_size,
+        max_seq_len,
+        xdim,
+        0);
     CHECK_EQ(r, 0);
   }
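The Run() changes replace the hand-staged device lod (scratchpad allocation plus xpu_memcpy) with a brace-initialized triple passed straight into the xdnn calls; for sequence_reverse the separate batch_size argument also goes away, since the lod already encodes the batch. Judging by the call sites, the triple is {host pointer, length, device pointer}, and passing nullptr for the device pointer leaves the device-side staging to the library. A rough sketch of the shape of that parameter, for exposition only — the real definition lives in the XDNN headers:

    // Approximation of the lod argument, assuming xdnn's VectorParam-style
    // layout; not the actual header.
    template <typename T>
    struct VectorParamSketch {
      const T* cpu;  // host copy of the offsets, e.g. int_lod.data()
      int len;       // number of offsets, e.g. static_cast<int>(int_lod.size())
      T* xpu;        // optional device copy; nullptr lets the library stage it
    };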

@@ -289,29 +265,28 @@ void XPUDynamicLstmCompute::Run() {
   const float* weight_0_maxptr = reinterpret_cast<float*>(weight_0_max_->addr_);
   const float* weight_1_maxptr = reinterpret_cast<float*>(weight_1_max_->addr_);

-  r = xdnn::lstm_train_for_old_paddle<float, float, int16_t>(
-      ctx.GetRawContext(),
-      transpose_in_addr,
-      h0,
-      c0,
-      transpose_weight_0_addr,
-      transpose_weight_1_addr,
-      bias_0_addr,
-      bias_1_addr,
-      transpose_out_addr,
-      last_h_addr,
-      last_c_addr,
-      batch_size,
-      xdim,
-      hdim,
-      max_seq_len,
-      seq_len_tensor,
-      nullptr,
-      nullptr,
-      weight_0_maxptr,
-      weight_1_maxptr,
-      i_f_g_o_addr,
-      c_addr);
+  r = xdnn::lstm_train<float, float, int16_t>(ctx.GetRawContext(),
+                                              transpose_in_addr,
+                                              h0,
+                                              c0,
+                                              transpose_weight_0_addr,
+                                              transpose_weight_1_addr,
+                                              bias_0_addr,
+                                              bias_1_addr,
+                                              transpose_out_addr,
+                                              last_h_addr,
+                                              last_c_addr,
+                                              batch_size,
+                                              xdim,
+                                              hdim,
+                                              max_seq_len,
+                                              seq_len_tensor,
+                                              nullptr,
+                                              nullptr,
+                                              weight_0_maxptr,
+                                              weight_1_maxptr,
+                                              i_f_g_o_addr,
+                                              c_addr);
   CHECK_EQ(r, 0);

   // transpose from transpose_out[seq_len, batch_size, hdim] to
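With the gate blocks already reordered in PrepareForRun(), the compatibility entry point lstm_train_for_old_paddle gives way to the plain lstm_train kernel; the argument list is otherwise identical. For reference, the per-timestep recurrence that the [i, f, g, o] packing feeds — standard LSTM equations written as a host-side sketch, not the XPU kernel itself:

    #include <cmath>

    // Reference LSTM cell step in [i, f, g, o] gate order. `pre` holds the four
    // pre-activation gate blocks (W*x + U*h + b), `hdim` floats each. Purely
    // illustrative of the packing.
    void LstmStep(const float* pre, float* c, float* h, int hdim) {
      auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };
      for (int j = 0; j < hdim; ++j) {
        float i = sigmoid(pre[j]);               // input gate
        float f = sigmoid(pre[hdim + j]);        // forget gate
        float g = std::tanh(pre[2 * hdim + j]);  // candidate cell
        float o = sigmoid(pre[3 * hdim + j]);    // output gate
        c[j] = f * c[j] + i * g;
        h[j] = o * std::tanh(c[j]);
      }
    }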
@@ -337,27 +312,24 @@ void XPUDynamicLstmCompute::Run() {
         ctx.GetRawContext(),
         out_batch_tensor_addr,
         reverse_output_addr,
-        {reverse_int_lod.data(),
-         static_cast<int>(reverse_int_lod.size()),
-         xpu_reverse_int_lod_addr},
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
         max_seq_len,
         hdim);
     CHECK_EQ(r, 0);

     r = xdnn::sequence_reverse<float, int>(
         ctx.GetRawContext(),
         reverse_output_addr,
-        xpu_reverse_int_lod_addr,
         param.hidden->template mutable_data<float>(TARGET(kXPU)),
-        batch_size,
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
         hdim);
     CHECK_EQ(r, 0);
   } else {
     r = xdnn::sequence_unpad<float, int>(
         ctx.GetRawContext(),
         out_batch_tensor_addr,
         param.hidden->template mutable_data<float>(TARGET(kXPU)),
-        {int_lod.data(), static_cast<int>(int_lod.size()), xpu_lod_addr},
+        {int_lod.data(), static_cast<int>(int_lod.size()), nullptr},
         max_seq_len,
         hdim);
     CHECK_EQ(r, 0);
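The final hunk drops reverse_int_lod and the std::reverse of seq_len_tensor entirely: reversing elements within each sequence does not move sequence boundaries, so the original int_lod offsets stay valid for the reversed data, and rebuilding a lod from reversed lengths (as the old code did) would describe the wrong layout whenever sequences differ in length. A toy illustration of that invariant, assuming sequence_reverse reverses within each lod segment:

    #include <algorithm>
    #include <vector>

    int main() {
      // Two sequences packed back to back, lod offsets {0, 2, 5}.
      std::vector<float> data = {10, 11, 20, 21, 22};
      std::vector<int> lod = {0, 2, 5};
      // Per-segment reversal, the effect attributed to xdnn::sequence_reverse:
      for (size_t s = 0; s + 1 < lod.size(); ++s)
        std::reverse(data.begin() + lod[s], data.begin() + lod[s + 1]);
      // data is now {11, 10, 22, 21, 20}; lod is still {0, 2, 5}.
      return 0;
    }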