Skip to content

Commit

Permalink
fixed the bug of tile op in large input and add xpu implemention.
Browse files Browse the repository at this point in the history
  • Loading branch information
wbn03 committed Jun 22, 2022
1 parent c9d1213 commit f95d41b
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 3 deletions.
7 changes: 4 additions & 3 deletions lite/kernels/host/tile_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,10 @@ void TileCompute<T, PType>::Run() {
int dst_stride = in_stride[i + 1] * right;
for (int m = 0; m < num; m++) {
for (int j = 0; j < bcast_dims[i]; j++) {
std::memcpy(tmp_dst + j * dst_stride / bcast_dims[i] + m * dst_stride,
tmp_src + m * dst_stride / bcast_dims[i],
dst_stride / bcast_dims[i] * sizeof(T));
std::memcpy(
tmp_dst + j * (dst_stride / bcast_dims[i]) + m * dst_stride,
tmp_src + m * (dst_stride / bcast_dims[i]),
dst_stride / bcast_dims[i] * sizeof(T));
}
}
tmp_src_tensor.CopyDataFrom(tmp_dst_tensor);
Expand Down
1 change: 1 addition & 0 deletions lite/kernels/xpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ add_kernel(gru_compute_xpu XPU basic SRCS gru_compute.cc)
add_kernel(gru_unit_compute_xpu XPU basic SRCS gru_unit_compute.cc)
add_kernel(stack_compute_xpu XPU basic SRCS stack_compute.cc)
add_kernel(slice_compute_xpu XPU basic SRCS slice_compute.cc)
add_kernel(tile_compute_xpu XPU basic SRCS tile_compute.cc)
add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc)
add_kernel(sequence_topk_avg_pooling_compute_xpu XPU basic SRCS sequence_topk_avg_pooling_compute.cc)
add_kernel(concat_compute_xpu XPU basic SRCS concat_compute.cc)
Expand Down
78 changes: 78 additions & 0 deletions lite/kernels/xpu/tile_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/tile_compute.h"
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T, PrecisionType PType>
void TileCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto repeat_times = param.repeat_times;
if (param.RepeatTimes) {
auto repeat_times_size = param.RepeatTimes->data_size();
for (int64_t i = 0; i < repeat_times_size; i++) {
repeat_times.push_back(param.RepeatTimes->template data<int>()[i]);
}
} else if (param.repeat_times_tensor.size() != 0) {
for (int i = 0; i < param.repeat_times_tensor.size(); i++) {
auto temp = param.repeat_times_tensor[i];
repeat_times.push_back(*(temp->template data<int>()));
}
}
auto in_dims = param.X->dims();
auto vec_in_dims = in_dims.Vectorize();
// broadcast for vec_in_dims.size() equal to repeat_times.size()
if (repeat_times.size() < vec_in_dims.size()) {
int diff = vec_in_dims.size() - repeat_times.size();
repeat_times.insert(repeat_times.begin(), diff, 1);
} else {
int diff = repeat_times.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
}

std::vector<int> new_in_dims(vec_in_dims.begin(), vec_in_dims.end());
std::vector<int> out_dims(param.Out->dims().data().begin(),
param.Out->dims().data().end());
int r = xdnn::broadcast<T>(ctx.GetRawContext(),
param.X->template data<T>(),
param.Out->template mutable_data<T>(TARGET(kXPU)),
new_in_dims,
out_dims);

CHECK_EQ(r, 0);
}

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle

using tile_float =
paddle::lite::kernels::xpu::TileCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(tile, kXPU, kFloat, kNCHW, tile_float, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("RepeatTimes",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("repeat_times_tensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
36 changes: 36 additions & 0 deletions lite/kernels/xpu/tile_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T, PrecisionType PType>
class TileCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::TileParam;

virtual void Run();

virtual ~TileCompute() = default;
};

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
1 change: 1 addition & 0 deletions lite/operators/tile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ bool TileOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
} else if (opdesc.HasInput("repeat_times_tensor") &&
(opdesc.Input("repeat_times_tensor").size() != 0)) {
auto temp = opdesc.Input("repeat_times_tensor");
param_.repeat_times_tensor.clear();
for (auto var : temp) {
param_.repeat_times_tensor.push_back(
scope->FindVar(var)->GetMutable<lite::Tensor>());
Expand Down
3 changes: 3 additions & 0 deletions lite/tests/kernels/tile_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ TEST(tile, precision) {
#else
return;
#endif
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
alias = "def";
#elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
place = TARGET(kHost);
#else
Expand Down

0 comments on commit f95d41b

Please sign in to comment.