diff --git a/lite/kernels/host/topk_v2_compute.cc b/lite/kernels/host/topk_v2_compute.cc index 896d5b863f5..539e768e555 100644 --- a/lite/kernels/host/topk_v2_compute.cc +++ b/lite/kernels/host/topk_v2_compute.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -44,19 +44,21 @@ void TopkV2Compute::Run() { int inner_size = x_dims.count(axis + 1, dim_size); int sum_size = axis_size * inner_size; int out_sum_size = k * inner_size; - for (int n = 0; n < outer_size; n++) { - const float* in_data = x_data + n * sum_size; - float* out_data = out_val + n * out_sum_size; - int64_t* out_ind_data = out_ind + n * out_sum_size; - for (int i = 0; i < inner_size; i++) { - std::vector> vec; - for (int j = 0; j < axis_size; j++) { - vec.push_back(std::make_pair(in_data[j * inner_size + i], j)); - } - std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp_func); - for (int j = 0; j < k; j++) { - out_data[j * inner_size + i] = vec[j].first; - out_ind_data[j * inner_size + i] = vec[j].second; + + for (int i = 0; i < outer_size; i++) { + int glb_in_off = i * sum_size; + int glb_out_off = i * out_sum_size; + std::vector> vec; + for (int j = 0; j < axis_size; j++) { + vec.push_back(std::make_pair(x_data[glb_in_off + j * inner_size], j)); + } + std::partial_sort( + vec.begin(), vec.begin() + k, vec.end(), comp_func); + for (int j = 0; j < k; j++) { + for (int k = 0; k < inner_size; k++) { + int cur_off = glb_in_off + vec[j].second * inner_size + k; + out_val[glb_out_off + j * inner_size + k] = x_data[cur_off]; + out_ind[glb_out_off + j * inner_size + k] = vec[j].second; } } }