[ARM] fix global_avg_pool and rnn fp16 overflow #9770

Merged
merged 1 commit into from Nov 30, 2022
17 changes: 13 additions & 4 deletions lite/backends/arm/math/fp16/pooling_fp16.cc
@@ -298,8 +298,10 @@ void pooling_basic_fp16(POOLING_PARAM,
"1: \n" \
"fadd v4.8h, v0.8h, v2.8h\n" \
"fadd v5.8h, v1.8h, v3.8h\n" \
"fmul v4.8h, %[vsize].8h, v4.8h\n" \
"ldp q0, q1, [%[data_in_channel]], #32\n" \
"fadd %[vsum].8h, %[vsum].8h, v4.8h\n" \
"fmul v5.8h, %[vsize].8h, v5.8h\n" \
"ldp q2, q3, [%[data_in_channel]], #32\n" \
"subs %w[cnt], %w[cnt], #1 \n" \
"fadd %[vsum].8h, %[vsum].8h, v5.8h\n" \
@@ -324,6 +326,7 @@ void pooling_basic_fp16(POOLING_PARAM,
"blt 3f\n" \
"2: \n" \
"subs %w[remain], %w[remain], #1 \n" \
"fmul v0.8h, v0.8h, %[vsize].8h\n" \
"fadd %[vsum].8h, %[vsum].8h, v0.8h\n" \
"ld1 {v0.8h}, [%[data_in_channel]], #16\n" \
"bne 2b \n" \
@@ -572,8 +575,10 @@ void pooling_basic_fp16(POOLING_PARAM,
"vadd.f16 q4, q0, q2\n" \
"vadd.f16 q5, q1, q3\n" \
"vld1.16 {d0-d3}, [%[data_in_channel]]!\n" \
"vmul.f16 q4, q4, %q[vsize]\n" \
"vadd.f16 %q[vsum], %q[vsum], q4\n" \
"vld1.16 {d4-d7}, [%[data_in_channel]]!\n" \
"vmul.f16 q5, q5, %q[vsize]\n" \
"vadd.f16 %q[vsum], %q[vsum], q5\n" \
"subs %[cnt], %[cnt], #1\n" \
"bne 1b\n"
@@ -585,6 +590,7 @@ void pooling_basic_fp16(POOLING_PARAM,
"blt 3f\n" \
"2:\n" \
"subs %[remain], %[remain], #1\n" \
"vmul.f16 q0, q0, %q[vsize]\n" \
"vadd.f16 %q[vsum], %q[vsum], q0\n" \
"vld1.16 {d0, d1}, [%[data_in_channel]]!\n" \
"bne 2b \n" \
@@ -1316,11 +1322,13 @@ void pooling_global_max_fp16(POOLING_PARAM) {

void pooling_global_avg_fp16(POOLING_PARAM) {
int size_channel_in = win * hin;

int cnt = size_channel_in >> 5;
int remain = size_channel_in & 31;
int cnt_8 = remain >> 3;
int remain_8 = remain & 7;
float16_t size_channel_in_1 = 1.f / size_channel_in;
float16x8_t vec_size_channel = vdupq_n_f16(size_channel_in_1);

for (int n = 0; n < num; ++n) {
float16_t *data_out_batch = dout + n * chout;
const float16_t *data_in_batch = din + n * chin * size_channel_in;
@@ -1336,7 +1344,7 @@ void pooling_basic_fp16(POOLING_PARAM,
[cnt] "+r"(size_cnt),
[remain] "+r"(size_remain),
[vsum] "+w"(vsum)
:
: [vsize] "w"(vec_size_channel)
#ifdef __aarch64__
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6");
#else
@@ -1346,11 +1354,12 @@ void pooling_basic_fp16(POOLING_PARAM,
float16x4_t vsum_tmp = vadd_f16(vget_low_f16(vsum), vget_high_f16(vsum));
float16x4_t vtmp1 = vpadd_f16(vsum_tmp, vsum_tmp);
float16x4_t vtmp2 = vpadd_f16(vtmp1, vtmp1);
float16_t res = vtmp2[0];
for (int i = 0; i < remain_8; i++) {
vtmp2[0] += data_in_channel[0];
res += data_in_channel[0] / size_channel_in;
data_in_channel++;
}
data_out_batch[c] = vtmp2[0] / size_channel_in;
data_out_batch[c] = res;
}
LITE_PARALLEL_END()
}
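The hunks above are the core of the global_avg_pool fix: instead of accumulating raw fp16 values into `vsum` and dividing by the window size once at the end, every partial sum is pre-multiplied by `vec_size_channel` (`1 / size_channel_in`) via the added `fmul`/`vmul.f16` instructions, the scalar tail divides each leftover element as it is added, and the new `[vsize] "w"(vec_size_channel)` input operand threads the scale into the inline assembly. A large global window can push the old running sum past the fp16 maximum of 65504 and saturate it to inf, while the scaled accumulator stays near the magnitude of the final average. A minimal scalar sketch of the two approaches (illustrative only, not code from this patch; the sizes in the comments are hypothetical):

```cpp
#include <arm_neon.h>  // provides float16_t on ARM toolchains with fp16 support

// Overflow-prone: e.g. a hypothetical 112x112 channel of values around 16.0
// sums to roughly 2e5, well past the fp16 maximum of 65504, so the running
// fp16 sum saturates to +inf before the final division can rescue it.
float16_t global_avg_naive(const float16_t* in, int size) {
  float16_t sum = 0.f;
  for (int i = 0; i < size; ++i) sum += in[i];
  return sum / size;
}

// The approach taken by this patch: scale every term by 1/size up front
// (size_channel_in_1 / vec_size_channel in the diff), so the accumulator
// never grows far beyond the magnitude of the final average.
float16_t global_avg_scaled(const float16_t* in, int size) {
  float16_t scale = 1.f / size;
  float16_t sum = 0.f;
  for (int i = 0; i < size; ++i) sum += in[i] * scale;
  return sum;
}
```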
9 changes: 6 additions & 3 deletions lite/backends/arm/math/sve/pooling_sve.cc
@@ -87,6 +87,9 @@ void pooling_global_avg_fp16_sve(const float16_t* din,
int size_channel_in = win * hin;
auto data_out = static_cast<float16_t*>(dout);
auto data_in = static_cast<const float16_t*>(din);
float16_t size_channel_in_1 = 1.f / size_channel_in;
svfloat16_t vsize = svdup_n_f16(size_channel_in_1);

for (int n = 0; n < num; ++n) {
float16_t* data_out_batch = data_out + n * chout;
const float16_t* data_in_batch = data_in + n * chin * size_channel_in;
@@ -104,6 +107,7 @@ void pooling_global_avg_fp16_sve(const float16_t* din,
"add x0, x0, %x[cnth]\n"
"ld1h {z0.h}, p0/Z, [%x[data_in_channel]]\n"
"add %x[data_in_channel], %x[data_in_channel], %x[cntb]\n"
"fmul z0.h, p0/M, z0.h, %[vsize].h\n"
"fadd z1.h, p0/M, z1.h, z0.h\n"
"whilelt p0.h, x0, %x[size_channel_in]\n"
"b.any 1b\n"
@@ -114,14 +118,13 @@ void pooling_global_avg_fp16_sve(const float16_t* din,
: [size_channel_in] "r"(size_channel_in),
[cnth] "r"(cnth),
[cntb] "r"(cntb),
[data_out_channel] "r"(data_out_channel)
[data_out_channel] "r"(data_out_channel),
[vsize] "w"(vsize)
: "cc", "memory", "z0", "z1", "p0", "x0");
data_out_channel[0] = data_out_channel[0] / size_channel_in;
}
LITE_PARALLEL_END();
}
}

#endif

} // namespace math
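The SVE path gets the same treatment: the added `fmul z0.h, p0/M, z0.h, %[vsize].h` scales each predicated load before it is accumulated into `z1`, `vsize` is passed in through a new `"w"` input operand, and the trailing `data_out_channel[0] / size_channel_in` is deleted. A rough ACLE-intrinsics rendering of the resulting loop, offered only as a sketch (the intrinsic form and function name are my reconstruction, not code from the repo), assuming an SVE-capable compiler:

```cpp
#include <arm_sve.h>  // SVE ACLE intrinsics; float16_t is also defined here

// Predicated global-average loop: scale-before-accumulate, no final division.
float16_t global_avg_fp16_sve_sketch(const float16_t* in, int32_t n) {
  svfloat16_t vscale = svdup_n_f16((float16_t)(1.0f / n));  // plays the role of vsize in the diff
  svfloat16_t vsum = svdup_n_f16(0);
  for (int32_t i = 0; i < n; i += (int32_t)svcnth()) {
    svbool_t pg = svwhilelt_b16(i, n);       // predicate also covers the ragged tail
    svfloat16_t v = svld1_f16(pg, in + i);
    v = svmul_f16_m(pg, v, vscale);          // the added fmul: scale before accumulating
    vsum = svadd_f16_m(pg, vsum, v);         // inactive lanes keep their old vsum value
  }
  return svaddv_f16(svptrue_b16(), vsum);    // horizontal reduction; no trailing division needed
}
```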
21 changes: 0 additions & 21 deletions lite/kernels/arm/rnn_compute.cc
@@ -1535,27 +1535,6 @@ void RnnCompute<PRECISION(kFP16)>::Run() {
} // namespace lite
} // namespace paddle

#ifdef ENABLE_ARM_FP16
using rnn_f16_compute =
paddle::lite::kernels::arm::RnnCompute<PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(rnn, kARM, kFP16, kNCHW, rnn_f16_compute, fp16)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindInput("WeightList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindInput("PreState",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindInput("SequenceLength",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("DropoutState",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindOutput("Reserve",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindOutput("State",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.Finalize();
#endif // ENABLE_ARM_FP16

using rnn_f32_compute =
paddle::lite::kernels::arm::RnnCompute<PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(rnn, kARM, kFloat, kNCHW, rnn_f32_compute, def)
2 changes: 1 addition & 1 deletion lite/tests/math/pool_fp16_compute_test.cc
@@ -262,7 +262,7 @@ TEST(TestPoolRand, test_pool_rand) {
}
#endif /// random param conv

#ifdef LITE_WITH_ARM8_SVE2 /// global_pool
#if 1 /// global_pool
TEST(TesPoolGlobal, test_pool_fp16_global) {
for (auto& h : {51})
test_pool_fp16({DDim({1, 64, h, h})},