Skip to content

Commit

Permalink
[Host] move beam_search (#5759)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang authored Mar 23, 2021
1 parent 6c76f3c commit 22fca16
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 658 deletions.
1 change: 0 additions & 1 deletion lite/backends/arm/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
norm.cc
pad2d.cc
negative.cc
beam_search.cc
reduce_max.cc
reduce_min.cc
reduce_max_min.cc
Expand Down
1 change: 0 additions & 1 deletion lite/backends/arm/math/funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "lite/backends/arm/math/anchor_generator.h"
#include "lite/backends/arm/math/argmax.h"
#include "lite/backends/arm/math/axpy.h"
#include "lite/backends/arm/math/beam_search.h"
#include "lite/backends/arm/math/box_coder.h"
#include "lite/backends/arm/math/clip.h"
#include "lite/backends/arm/math/col_im_transform.h"
Expand Down
1 change: 1 addition & 0 deletions lite/backends/host/math/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
lite_cc_library(math_host SRCS
beam_search.cc
sequence_padding.cc
slice.cc
pad3d.cc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/arm/math/beam_search.h"
#include <arm_neon.h>
#include "lite/backends/host/math/beam_search.h"
#include <cmath>
#include <string>
#include <vector>
#include "lite/utils/cp_logging.h"

namespace paddle {
namespace lite {
namespace arm {
namespace host {
namespace math {
/*
* The basic items help to sort.
Expand Down Expand Up @@ -207,9 +205,7 @@ void beam_search(const Tensor *pre_ids,
int level,
int beam_size,
int end_id,
bool is_accumulated,
Context<TARGET(kARM)> *ctx) {
// auto abs_lod = framework::ToAbsOffset(scores->lod());
bool is_accumulated) {
auto abs_lod = scores->lod();
auto &high_level = abs_lod[level];
auto items = SelectTopBeamSizeItems(pre_ids,
Expand Down Expand Up @@ -266,6 +262,6 @@ void beam_search(const Tensor *pre_ids,
}

} // namespace math
} // namespace arm
} // namespace host
} // namespace lite
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
// limitations under the License.

#pragma once

#include <cmath>
#include "lite/core/context.h"

namespace paddle {
namespace lite {
namespace arm {
namespace host {
namespace math {

void beam_search(const Tensor* pre_ids,
Expand All @@ -32,10 +30,9 @@ void beam_search(const Tensor* pre_ids,
int level,
int beam_size,
int end_id,
bool is_accumulated,
Context<TARGET(kARM)>* ctx);
bool is_accumulated);

} // namespace math
} // namespace arm
} // namespace host
} // namespace lite
} // namespace paddle
2 changes: 0 additions & 2 deletions lite/backends/x86/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ math_library(sequence2batch)
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale)
math_library(softmax DEPS math_function jit_kernel_helper)
math_library(beam_search DEPS math_function)
#
## math_library(matrix_bit_code)
#
Expand All @@ -90,7 +89,6 @@ endif()
# cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col)
# cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
# cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling)
# cc_test(beam_search_test SRCS beam_search_test.cc DEPS beam_search)
# cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
# cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
math_library(box_coder DEPS math_function)
Expand Down
Loading

0 comments on commit 22fca16

Please sign in to comment.