diff --git a/lite/kernels/host/topk_v2_compute.cc b/lite/kernels/host/topk_v2_compute.cc index 896d5b863f5..05e34b89ba4 100644 --- a/lite/kernels/host/topk_v2_compute.cc +++ b/lite/kernels/host/topk_v2_compute.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -44,19 +44,20 @@ void TopkV2Compute::Run() { int inner_size = x_dims.count(axis + 1, dim_size); int sum_size = axis_size * inner_size; int out_sum_size = k * inner_size; - for (int n = 0; n < outer_size; n++) { - const float* in_data = x_data + n * sum_size; - float* out_data = out_val + n * out_sum_size; - int64_t* out_ind_data = out_ind + n * out_sum_size; - for (int i = 0; i < inner_size; i++) { - std::vector> vec; - for (int j = 0; j < axis_size; j++) { - vec.push_back(std::make_pair(in_data[j * inner_size + i], j)); - } - std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp_func); - for (int j = 0; j < k; j++) { - out_data[j * inner_size + i] = vec[j].first; - out_ind_data[j * inner_size + i] = vec[j].second; + + for (int i = 0; i < outer_size; i++) { + int glb_in_off = i * sum_size; + int glb_out_off = i * out_sum_size; + std::vector> vec; + for (int j = 0; j < axis_size; j++) { + vec.push_back(std::make_pair(x_data[glb_in_off + j * inner_size], j)); + } + std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp_func); + for (int j = 0; j < k; j++) { + for (int k = 0; k < inner_size; k++) { + int cur_off = glb_in_off + vec[j].second * inner_size + k; + out_val[glb_out_off + j * inner_size + k] = x_data[cur_off]; + out_ind[glb_out_off + j * inner_size + k] = vec[j].second; } } } diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index 332d458b6f1..fb481ee42b3 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -47,6 +47,8 @@ lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_ lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS ${test_kernel_deps}) lite_cc_test(test_kernel_logical_compute SRCS logical_compute_test.cc DEPS ${test_kernel_deps}) lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS ${test_kernel_deps}) +lite_cc_test(test_kernel_topk_v2_compute SRCS topk_v2_compute_test.cc DEPS ${test_kernel_deps}) + lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS ${test_kernel_deps}) lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS ${test_kernel_deps}) lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS ${test_kernel_deps}) diff --git a/lite/tests/kernels/topk_v2_compute_test.cc b/lite/tests/kernels/topk_v2_compute_test.cc index b97920625d4..4f0808e1fdb 100644 --- a/lite/tests/kernels/topk_v2_compute_test.cc +++ b/lite/tests/kernels/topk_v2_compute_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -48,11 +48,13 @@ class TopkV2ComputeTester : public arena::TestCase { void RunBaseline(Scope* scope) override { auto* out_val = scope->NewTensor(out_); auto* out_ind = scope->NewTensor(indices_); + DDim out_dims = x_dims_; if (axis_ < 0) { axis_ += x_dims_.size(); } out_dims[axis_] = k_; + out_val->Resize(out_dims); out_ind->Resize(out_dims); auto* out_val_data = out_val->template mutable_data(); @@ -60,24 +62,27 @@ class TopkV2ComputeTester : public arena::TestCase { auto* x = scope->FindTensor(x_); const auto* x_data = x->template data(); + int inner_size = x_dims_.count(axis_ + 1, x_dims_.size()); int axis_size = x_dims_[axis_]; int outer_size = x_dims_.count(0, axis_); - int out_sum_size = k * inner_size; - for (int n = 0; n < outer_size; n++) { - const float* in_data = x_data + n * sum_size; - float* out_val_data1 = out_val_data + n * out_sum_size; - int64_t* out_ind_data1 = out_ind_data + n * out_sum_size; - for (int i = 0; i < inner_size; i++) { - std::vector> vec; - for (int j = 0; j < axis_size; j++) { - vec.push_back(std::make_pair(in_data[j * outer_size + i], j)); - } - std::partial_sort( - vec.begin(), vec.begin() + k_, vec.end(), comp_func); - for (int j = 0; j < k_; j++) { - out_val_data1[j * outer_size + i] = vec[j].first; - out_ind_data1[j * outer_size + i] = vec[j].second; + int out_sum_size = k_ * inner_size; + int sum_size = axis_size * inner_size; + + for (int i = 0; i < outer_size; i++) { + int glb_in_off = i * sum_size; + int glb_out_off = i * out_sum_size; + std::vector> vec; + for (int j = 0; j < axis_size; j++) { + vec.push_back(std::make_pair(x_data[glb_in_off + j * inner_size], j)); + } + std::partial_sort( + vec.begin(), vec.begin() + k_, vec.end(), comp_func); + for (int j = 0; j < k_; j++) { + for (int k = 0; k < inner_size; k++) { + int cur_off = glb_in_off + vec[j].second * inner_size + k; + out_val_data[glb_out_off + j * inner_size + k] = x_data[cur_off]; + out_ind_data[glb_out_off + j * inner_size + k] = vec[j].second; } } } @@ -101,10 +106,27 @@ class TopkV2ComputeTester : public arena::TestCase { template void test_topk_v2(Place place, float abs_error) { + int caseNum = 0; for (auto x_shape : std::vector>{{2, 3, 4, 5}, {3, 4, 5}, {4, 5}}) { for (int axis : {-1, -2}) { for (int k : {2, 5}) { + std::cout << "start case " << caseNum++ << ":" << std::endl; + auto axis_valid = ((axis >= (-1 * (int)x_shape.size())) && + (axis < (int)x_shape.size())); + if (!axis_valid) { + LOG(INFO) << "the axis of topk_v2 must be [" << (-1 * x_shape.size()) + << ", " << x_shape.size() << "but you set axis is" << axis; + continue; + } + if (axis < 0) { + axis += x_shape.size(); + } + if (x_shape[axis] < k) { + LOG(INFO) << "input of topk_v2 op must have >=" << k + << " columns in axis of " << x_shape[axis]; + continue; + } std::unique_ptr tester(new TopkV2ComputeTester( place, "def", DDim(x_shape), axis, k)); arena::Arena arena(std::move(tester), place, abs_error); @@ -116,9 +138,9 @@ void test_topk_v2(Place place, float abs_error) { TEST(Topk, precision) { Place place; - float abs_error = 2e-5; #if defined(LITE_WITH_ARM) place = TARGET(kHost); + float abs_error = 2e-5; test_topk_v2(place, abs_error); #else return;