[XPU] fix some bugs for transformer (PaddlePaddle#7014)
wanglei91 authored and newway committed Dec 24, 2021
1 parent 8b98474 commit 17ed939
Showing 6 changed files with 62 additions and 1 deletion.
16 changes: 16 additions & 0 deletions lite/kernels/host/compare_compute.cc
@@ -520,6 +520,22 @@ REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def)
.BindPaddleOpVersion("greater_than", 1)
.Finalize();

using greater_than_bool = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterThanFunctor<bool>>;
REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_bool, bool)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.BindPaddleOpVersion("greater_than", 1)
.Finalize();

using greater_than_int32 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt32),
paddle::lite::kernels::host::_GreaterThanFunctor<int32_t>>;
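
Note: the new registration reuses the existing comparison-functor template with bool as the element type, binding X, Y, and Out as kBool tensors. As a rough sketch of the functor family the registration instantiates (the actual definition lives elsewhere in the Lite sources; this is an assumption based on the names in the diff):

// Hedged sketch of the comparison functor the kernel above instantiates;
// CompareCompute<PRECISION(kFloat), _GreaterThanFunctor<bool>> applies it
// elementwise, i.e. Out[i] = (X[i] > Y[i]) over bool tensors.
template <typename T>
struct _GreaterThanFunctor {
  inline bool operator()(const T& a, const T& b) const { return a > b; }
};
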
3 changes: 3 additions & 0 deletions lite/kernels/host/gather_nd_compute.cc
@@ -78,6 +78,9 @@ void GatherNdCompute::Run() {
case PRECISION(kUInt8): \
GatherNd<uint8_t, index_data_type>(*x, *index, out); \
break; \
case PRECISION(kInt8): \
GatherNd<int8_t, index_data_type>(*x, *index, out); \
break; \
case PRECISION(kBool): \
GatherNd<bool, index_data_type>(*x, *index, out); \
break; \
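
Note: the hunk adds kInt8 to the precision switch, mirroring the existing kUInt8 and kBool cases. For context, gather_nd treats each row of the index tensor as a full coordinate into the source tensor; a minimal host-side sketch of the 2-D case (illustrative only, not Lite's GatherNd, which generalizes to arbitrary ranks):

#include <utility>
#include <vector>

// Illustrative gather_nd over a 2-D source: each entry of `index` is a
// (row, col) coordinate into `x`; the matching element is appended to `out`.
// With this commit, T may now be int8_t on the host kernel.
template <typename T, typename IndexT>
void GatherNd2D(const std::vector<std::vector<T>>& x,
                const std::vector<std::pair<IndexT, IndexT>>& index,
                std::vector<T>* out) {
  out->clear();
  for (const auto& coord : index) {
    out->push_back(x[coord.first][coord.second]);
  }
}
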
14 changes: 13 additions & 1 deletion lite/kernels/host/tile_compute.cc
@@ -129,7 +129,19 @@ REGISTER_LITE_KERNEL(tile, kHost, kInt64, kNCHW, tile_int64, def_int64)
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();

#ifdef LITE_BUILD_EXTRA
using tile_int64_f =
paddle::lite::kernels::host::TileCompute<int64_t, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(tile, kHost, kFloat, kNCHW, tile_int64_f, def_int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.BindInput("RepeatTimes",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("repeat_times_tensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();
#endif // LITE_BUILD_EXTRA
using tile_int8 =
paddle::lite::kernels::host::TileCompute<int8_t, PRECISION(kInt8)>;
REGISTER_LITE_KERNEL(tile, kHost, kInt8, kNCHW, tile_int8, def_int8)
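
Note: the new tile_int64_f registration (guarded by LITE_BUILD_EXTRA) lets an int64 tile kernel be selected in a graph whose kernels were picked under kFloat precision, while still binding int64 tensors for X and Out. For reference, tile repeats the input along each axis per RepeatTimes; a minimal 1-D sketch (illustrative, not Lite's TileCompute):

#include <cstdint>
#include <vector>

// Illustrative 1-D tile: the output is `repeat` back-to-back copies of `x`.
// Lite's TileCompute applies this per axis, driven by RepeatTimes.
std::vector<int64_t> Tile1D(const std::vector<int64_t>& x, int repeat) {
  std::vector<int64_t> out;
  out.reserve(x.size() * repeat);
  for (int r = 0; r < repeat; ++r) {
    out.insert(out.end(), x.begin(), x.end());
  }
  return out;
}
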
7 changes: 7 additions & 0 deletions lite/kernels/xpu/assign_value_compute.cc
@@ -27,6 +27,7 @@ void AssignValueCompute::Run() {
int dtype = param.dtype;
std::vector<float> fp32_values = param.fp32_values;
std::vector<int> int32_values = param.int32_values;
std::vector<int64_t> int64_values = param.int64_values;
CHECK_GT(param.shape.size(), 0UL);
if (dtype == static_cast<int>(lite::core::FluidType::INT32)) {
auto* out = param.Out->mutable_data<int>(TARGET(kXPU));
@@ -40,6 +41,12 @@
fp32_values.data(),
sizeof(float) * fp32_values.size(),
IoDirection::HtoD);
} else if (dtype == static_cast<int>(lite::core::FluidType::INT64)) {
auto* out = param.Out->mutable_data<int64_t>(TARGET(kXPU));
lite::TargetWrapperXPU::MemcpySync(out,
int64_values.data(),
sizeof(int64_t) * int64_values.size(),
IoDirection::HtoD);
} else {
LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype;
}
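
Note: the INT64 branch mirrors the existing INT32 and FP32 branches: the op attribute values are copied host-to-device in a single MemcpySync. A host-side analogue of that copy (a sketch, assuming a preallocated output buffer of the right size):

#include <cstdint>
#include <cstring>
#include <vector>

// Host-side analogue of the new INT64 branch: one bulk copy of the attribute
// values into the output buffer. On XPU the same copy goes through
// lite::TargetWrapperXPU::MemcpySync with IoDirection::HtoD.
void AssignInt64(int64_t* out, const std::vector<int64_t>& values) {
  std::memcpy(out, values.data(), sizeof(int64_t) * values.size());
}
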
8 changes: 8 additions & 0 deletions lite/kernels/xpu/fill_constant_batch_size_like_compute.cc
@@ -32,6 +32,14 @@ void FillConstantBatchSizeLikeCompute::Run() {
}
int r = 0;
switch (param.dtype) {
case 0: {
auto data = param.out->mutable_data<bool>(TARGET(kXPU));
r = xdnn::constant<bool>(ctx.GetRawContext(),
data,
write_size,
static_cast<bool>(param.value));
break;
}
case 1: {
auto data = param.out->mutable_data<int16_t>(TARGET(kXPU));
r = xdnn::constant<int16_t>(ctx.GetRawContext(),
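
Note: the switch dispatches on Paddle's numeric dtype codes, and the new case 0 covers bool ahead of the existing case 1 (int16). For orientation, the codes are believed to follow Paddle's framework VarType numbering (listed here as an aid, not Lite source):

// Dtype codes the switch is assumed to follow, per Paddle's VarType::Type
// numbering (shown for context only; not part of this file).
enum PaddleDType {
  kDTypeBool = 0,   // the case added by this commit
  kDTypeInt16 = 1,
  kDTypeInt32 = 2,
  kDTypeInt64 = 3,
  kDTypeFP16 = 4,
  kDTypeFP32 = 5,
};
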
15 changes: 15 additions & 0 deletions lite/kernels/xpu/transpose_compute.cc
@@ -30,6 +30,10 @@ void TransposeCompute::Run() {
int ndims = axis.size();
const auto x_dims = x->dims();
std::vector<int> x_shape_host(ndims, 0);
if (x_dims.production() == 0) {
param.output->set_target(TARGET(kXPU));
return;
}

for (int i = 0; i < ndims; ++i) {
x_shape_host[i] = x_dims[i];
@@ -68,3 +72,14 @@ REGISTER_LITE_KERNEL(transpose2,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
REGISTER_LITE_KERNEL(transpose2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::TransposeCompute,
def_int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindOutput("XShape",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.Finalize();
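
Note: two changes land in this file. First, an early return when the input has zero elements: there is nothing to permute, so the kernel only marks the output as XPU-resident instead of calling into xdnn. Second, a new transpose2 registration (alias def_int64) that binds int64 tensors for X, Out, and XShape. For intuition, a host-side sketch of the 2-D case the kernel generalizes (illustrative only; the real work is delegated to xdnn::transpose on device):

#include <cstdint>
#include <vector>

// Illustrative 2-D transpose for int64 data (axis = {1, 0}):
// out[j][i] = x[i][j], with both tensors stored row-major.
std::vector<int64_t> Transpose2D(const std::vector<int64_t>& x,
                                 int rows, int cols) {
  std::vector<int64_t> out(x.size());
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      out[j * rows + i] = x[i * cols + j];
    }
  }
  return out;
}
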
