fix bugs in topk_v2 kernel test=develop
zhenlin-work committed Aug 6, 2021
1 parent 33d00c5 commit 682088a
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions lite/kernels/host/topk_v2_compute.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -44,19 +44,21 @@ void TopkV2Compute::Run() {
   int inner_size = x_dims.count(axis + 1, dim_size);
   int sum_size = axis_size * inner_size;
   int out_sum_size = k * inner_size;
-  for (int n = 0; n < outer_size; n++) {
-    const float* in_data = x_data + n * sum_size;
-    float* out_data = out_val + n * out_sum_size;
-    int64_t* out_ind_data = out_ind + n * out_sum_size;
-    for (int i = 0; i < inner_size; i++) {
-      std::vector<std::pair<float, int>> vec;
-      for (int j = 0; j < axis_size; j++) {
-        vec.push_back(std::make_pair(in_data[j * inner_size + i], j));
-      }
-      std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp_func);
-      for (int j = 0; j < k; j++) {
-        out_data[j * inner_size + i] = vec[j].first;
-        out_ind_data[j * inner_size + i] = vec[j].second;
+
+  for (int i = 0; i < outer_size; i++) {
+    int glb_in_off = i * sum_size;
+    int glb_out_off = i * out_sum_size;
+    std::vector<std::pair<float, int>> vec;
+    for (int j = 0; j < axis_size; j++) {
+      vec.push_back(std::make_pair(x_data[glb_in_off + j * inner_size], j));
+    }
+    std::partial_sort(
+        vec.begin(), vec.begin() + k, vec.end(), comp_func);
+    for (int j = 0; j < k; j++) {
+      for (int k = 0; k < inner_size; k++) {
+        int cur_off = glb_in_off + vec[j].second * inner_size + k;
+        out_val[glb_out_off + j * inner_size + k] = x_data[cur_off];
+        out_ind[glb_out_off + j * inner_size + k] = vec[j].second;
       }
     }
   }
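For readers following the change, below is a minimal standalone C++ sketch (not part of the commit) that mirrors the indexing of the newly added loop: it ranks the slices along the top-k axis by their first inner element and copies each selected slice wholesale, which is what the offsets glb_in_off, glb_out_off, and cur_off do above. The shapes, sample values, the descending comp_func lambda, and the inner loop variable m (renamed from the kernel's shadowed k) are illustrative assumptions, not code from the repository.

// Standalone illustration of the new indexing (hypothetical shapes/values).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  const int outer_size = 1, axis_size = 4, inner_size = 3, k = 2;
  const int sum_size = axis_size * inner_size;  // input elements per outer slice
  const int out_sum_size = k * inner_size;      // output elements per outer slice
  // Input laid out as [outer_size, axis_size, inner_size]; row j holds
  // inner_size contiguous values.
  std::vector<float> x_data = {0.1f, 0.2f, 0.3f,    // j = 0
                               0.9f, 0.8f, 0.7f,    // j = 1
                               0.4f, 0.5f, 0.6f,    // j = 2
                               0.2f, 0.1f, 0.0f};   // j = 3
  std::vector<float> out_val(outer_size * out_sum_size);
  std::vector<int64_t> out_ind(outer_size * out_sum_size);
  // Plays the role of comp_func in the kernel: order pairs by value, descending.
  auto comp_func = [](const std::pair<float, int>& a,
                      const std::pair<float, int>& b) { return a.first > b.first; };

  for (int i = 0; i < outer_size; i++) {
    const int glb_in_off = i * sum_size;
    const int glb_out_off = i * out_sum_size;
    // Rank the axis_size slices by their first inner element.
    std::vector<std::pair<float, int>> vec;
    for (int j = 0; j < axis_size; j++) {
      vec.push_back(std::make_pair(x_data[glb_in_off + j * inner_size], j));
    }
    std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp_func);
    // Copy the whole inner slice of each of the k selected axis indices.
    for (int j = 0; j < k; j++) {
      for (int m = 0; m < inner_size; m++) {  // m corresponds to the inner k above
        const int cur_off = glb_in_off + vec[j].second * inner_size + m;
        out_val[glb_out_off + j * inner_size + m] = x_data[cur_off];
        out_ind[glb_out_off + j * inner_size + m] = vec[j].second;
      }
    }
  }

  for (int n = 0; n < outer_size * out_sum_size; n++) {
    std::cout << out_val[n] << " (index " << out_ind[n] << ")\n";
  }
  return 0;
}

With these sample values the program prints 0.9, 0.8, 0.7 with index 1 followed by 0.4, 0.5, 0.6 with index 2, i.e. the two axis slices whose leading elements are largest.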