diff --git a/lite/kernels/host/compare_compute.cc b/lite/kernels/host/compare_compute.cc
index 25fad84a5ce..3689dc4f206 100644
--- a/lite/kernels/host/compare_compute.cc
+++ b/lite/kernels/host/compare_compute.cc
@@ -147,6 +147,24 @@ REGISTER_LITE_KERNEL(equal, kHost, kInt64, kAny, equal_int64, def)
     .BindPaddleOpVersion("equal", 1)
     .Finalize();
 
+// float kernel has higher score when picking kernel.
+// TODO(zhupengyang): merge equal_int64 later
+using equal_int64_f = paddle::lite::kernels::host::CompareCompute<
+    PRECISION(kFloat),
+    paddle::lite::kernels::host::_EqualFunctor<int64_t>>;
+REGISTER_LITE_KERNEL(equal, kHost, kFloat, kAny, equal_int64_f, int64)
+    .BindInput("X",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
+    .BindPaddleOpVersion("equal", 1)
+    .Finalize();
+
 using equal_int32 = paddle::lite::kernels::host::CompareCompute<
     PRECISION(kInt32),
     paddle::lite::kernels::host::_EqualFunctor<int32_t>>;
diff --git a/lite/kernels/host/range_compute.cc b/lite/kernels/host/range_compute.cc
index 4923492cf92..10f354cd70d 100644
--- a/lite/kernels/host/range_compute.cc
+++ b/lite/kernels/host/range_compute.cc
@@ -100,3 +100,25 @@ REGISTER_LITE_KERNEL(range, kHost, kInt32, kAny, range_int32, def)
                                        PRECISION(kInt32),
                                        DATALAYOUT(kAny))})
     .Finalize();
+
+// float kernel has higher score when picking kernel.
+using range_int32_f =
+    paddle::lite::kernels::host::RangeCompute<int, PRECISION(kFloat)>;
+REGISTER_LITE_KERNEL(range, kHost, kFloat, kAny, range_int32_f, int32)
+    .BindInput("Start",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt32),
+                                      DATALAYOUT(kAny))})
+    .BindInput("End",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt32),
+                                      DATALAYOUT(kAny))})
+    .BindInput("Step",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt32),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt32),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
diff --git a/lite/kernels/x86/cast_compute.cc b/lite/kernels/x86/cast_compute.cc
index bbb63e59526..efbdd4d526e 100644
--- a/lite/kernels/x86/cast_compute.cc
+++ b/lite/kernels/x86/cast_compute.cc
@@ -34,3 +34,23 @@ REGISTER_LITE_KERNEL(
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFP16))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(cast,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::CastCompute<bool>,
+                     bool_to_any)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kBool))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(cast,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::CastCompute<int32_t>,
+                     int32_to_any)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
+    .Finalize();
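Note (not part of the patch): as the comments above say, the kernel-pick pass scores kFloat kernels highest, so these integer variants are registered under the kFloat precision slot while still binding kInt64/kInt32/kBool tensor types. A minimal sketch of how the new variants would be declared for linking, assuming the usual `USE_LITE_KERNEL(op, target, precision, layout, alias)` macro; the alias is the last argument of each `REGISTER_LITE_KERNEL` call above:

```cpp
// Illustrative only -- these lines are not part of the diff. They assume the
// standard USE_LITE_KERNEL macro; the target/precision/layout/alias values
// mirror the REGISTER_LITE_KERNEL calls added above.
USE_LITE_KERNEL(equal, kHost, kFloat, kAny, int64);
USE_LITE_KERNEL(range, kHost, kFloat, kAny, int32);
USE_LITE_KERNEL(cast, kX86, kFloat, kNCHW, bool_to_any);
USE_LITE_KERNEL(cast, kX86, kFloat, kNCHW, int32_to_any);
```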
diff --git a/lite/kernels/x86/slice_compute.cc b/lite/kernels/x86/slice_compute.cc
index 0291632e083..437f24b2303 100644
--- a/lite/kernels/x86/slice_compute.cc
+++ b/lite/kernels/x86/slice_compute.cc
@@ -27,3 +27,18 @@ REGISTER_LITE_KERNEL(slice,
     .BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(slice,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::SliceCompute<int32_t>,
+                     int32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .Finalize();
diff --git a/lite/kernels/xpu/reshape_compute.cc b/lite/kernels/xpu/reshape_compute.cc
index d70001085bb..c27f5f89310 100644
--- a/lite/kernels/xpu/reshape_compute.cc
+++ b/lite/kernels/xpu/reshape_compute.cc
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 #include "lite/kernels/xpu/reshape_compute.h"
 #include <vector>
 #include "lite/backends/xpu/xpu_header_sitter.h"
@@ -21,9 +22,10 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
-void ReshapeCompute::Run() {
-  auto& param = this->Param<param_t>();
-  auto& ctx = this->ctx_->As<XPUContext>();
+template <typename T>
+void ReshapeCompute<T>::Run() {
+  auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
   auto x = param.x;
   auto output = param.output;
   auto output_dims = output->dims();
@@ -32,10 +34,10 @@ void ReshapeCompute::Run() {
     output->ShareDataWith(*x);
     output->Resize(output_dims);
   } else {
-    int r = xdnn::copy<float>(ctx.GetRawContext(),
-                              param.x->data<float>(),
-                              param.output->mutable_data<float>(TARGET(kXPU)),
-                              param.x->numel());
+    int r = xdnn::copy<T>(ctx.GetRawContext(),
+                          x->template data<T>(),
+                          output->template mutable_data<T>(TARGET(kXPU)),
+                          x->numel());
     CHECK_EQ(r, 0);
   }
 
@@ -50,24 +52,58 @@ REGISTER_LITE_KERNEL(reshape2,
                      kXPU,
                      kFloat,
                      kNCHW,
-                     paddle::lite::kernels::xpu::ReshapeCompute,
-                     def)
+                     paddle::lite::kernels::xpu::ReshapeCompute<float>,
+                     float32)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
-    .BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kHost))})
-    .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(reshape2,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::ReshapeCompute<int>,
+                     int32)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
+    .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(reshape2,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::ReshapeCompute<int64_t>,
+                     int64)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
+    .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(reshape,
                      kXPU,
                      kFloat,
                      kNCHW,
-                     paddle::lite::kernels::xpu::ReshapeCompute,
-                     def)
+                     paddle::lite::kernels::xpu::ReshapeCompute<float>,
+                     float32)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
-    .BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kHost))})
-    .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
@@ -75,10 +111,11 @@ REGISTER_LITE_KERNEL(flatten,
                      kXPU,
                      kFloat,
                      kNCHW,
-                     paddle::lite::kernels::xpu::ReshapeCompute,
-                     def)
+                     paddle::lite::kernels::xpu::ReshapeCompute<float>,
+                     float32)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
-    .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
 
@@ -86,10 +123,11 @@ REGISTER_LITE_KERNEL(flatten2,
                      kXPU,
                      kFloat,
                      kNCHW,
-                     paddle::lite::kernels::xpu::ReshapeCompute,
-                     def)
+                     paddle::lite::kernels::xpu::ReshapeCompute<float>,
+                     float32)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
-    .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
     .Finalize();
diff --git a/lite/kernels/xpu/reshape_compute.h b/lite/kernels/xpu/reshape_compute.h
index 92d1a964060..28ce4f8f32b 100644
--- a/lite/kernels/xpu/reshape_compute.h
+++ b/lite/kernels/xpu/reshape_compute.h
@@ -21,6 +21,7 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
+template <typename T>
 class ReshapeCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
  public:
   using param_t = operators::ReshapeParam;
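For context, a sketch of what the declaration in reshape_compute.h looks like once this hunk is applied, reconstructed from the diff above; the include paths and the destructor are assumptions and only Run() is shown, so the real header may differ:

```cpp
// Sketch only -- reconstructed from the diff, not copied from the repository.
#pragma once

#include "lite/core/kernel.h"          // assumed location of KernelLite
#include "lite/operators/op_params.h"  // assumed location of ReshapeParam

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

template <typename T>
class ReshapeCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
 public:
  using param_t = operators::ReshapeParam;

  // Copies (or shares) the T-typed buffer; the float32, int32 and int64
  // registrations in reshape_compute.cc all instantiate this one class.
  void Run() override;

  virtual ~ReshapeCompute() = default;
};

}  // namespace xpu
}  // namespace kernels
}  // namespace lite
}  // namespace paddle
```

Since the templated Run() is defined in reshape_compute.cc and the REGISTER_LITE_KERNEL calls sit in the same translation unit, the float, int and int64_t instantiations are generated at the registration sites and no explicit instantiation is required.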