From c1dba6b1a43d095d70435ab43143fd1f3af1fc23 Mon Sep 17 00:00:00 2001
From: sprouteer
Date: Thu, 30 Jun 2022 01:09:51 +0800
Subject: [PATCH 1/4] add reshape buffer opencl test=develop

---
 lite/kernels/opencl/CMakeLists.txt            |   4 +
 lite/kernels/opencl/reshape_buffer_compute.cc | 114 +++++++++
 .../opencl/reshape_buffer_compute_test.cc     | 221 ++++++++++++++++++
 3 files changed, 339 insertions(+)
 create mode 100644 lite/kernels/opencl/reshape_buffer_compute.cc
 create mode 100644 lite/kernels/opencl/reshape_buffer_compute_test.cc

diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt
index cfa99d32a51..6e799d9f0d9 100644
--- a/lite/kernels/opencl/CMakeLists.txt
+++ b/lite/kernels/opencl/CMakeLists.txt
@@ -178,6 +178,7 @@ add_kernel(slice_opencl_buffer OPENCL basic SRCS slice_buffer_compute.cc)
 add_kernel(yolo_box_opencl_buffer OPENCL basic SRCS yolo_box_buffer_compute.cc)
 add_kernel(squeeze_unsqueeze_opencl_buffer OPENCL basic SRCS squeeze_unsqueeze_buffer_compute.cc)
 add_kernel(matmul_opencl_buffer OPENCL basic SRCS matmul_buffer_compute.cc)
+add_kernel(reshape_opencl_buffer OPENCL basic SRCS reshape_buffer_compute.cc)
 
 # extra
 # wait to add ...
@@ -215,3 +216,6 @@ lite_cc_test(test_fc_buffer_opencl SRCS fc_buffer_compute_test.cc
 lite_cc_test(test_io_copy_buffer_opencl SRCS io_copy_buffer_compute_test.cc
              DEPS kernels core)
+
+lite_cc_test(test_reshape_buffer_opencl SRCS reshape_buffer_compute_test.cc
+             DEPS kernels core)
diff --git a/lite/kernels/opencl/reshape_buffer_compute.cc b/lite/kernels/opencl/reshape_buffer_compute.cc
new file mode 100644
index 00000000000..a45e606a92d
--- /dev/null
+++ b/lite/kernels/opencl/reshape_buffer_compute.cc
@@ -0,0 +1,114 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/backends/opencl/cl_half.h"
+#include "lite/backends/opencl/cl_include.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/opencl/image_helper.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/log/logging.h"
+#include "lite/utils/replace_stl/stream.h"
+#ifdef LITE_WITH_PROFILE
+#include "lite/core/profile/profiler.h"
+#endif
+#include "lite/backends/opencl/cl_utility.h"
+
+#undef LITE_WITH_LOG
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace opencl {
+
+class ReshapeComputeFloatBuffer
+    : public KernelLite<TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)> {
+ public:
+  using param_t = operators::ReshapeParam;
+
+  void PrepareForRun() override { auto& context = ctx_->As<OpenCLContext>(); }
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+    const Tensor* const x = param.x;
+    Tensor* const output = param.output;
+
+    auto output_dims = output->dims();
+    auto output_lod = output->lod();
+    if (param.inplace) {
+      output->ShareDataWith(*x);
+    } else {
+      output->CopyDataFrom(*x);
+    }
+    output->Resize(output_dims);
+    output->set_lod(output_lod);
+
+#ifdef LITE_WITH_LOG
+    VLOG(4) << TargetToStr(x->target());
+    VLOG(4) << TargetToStr(param.output->target());
+#endif
+  }
+
+ private:
+  std::string time_stamp_{GetTimeStamp()};
+};
+
+}  // namespace opencl
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(reshape,
+                     kOpenCL,
+                     kFP16,
+                     kNCHW,
+                     paddle::lite::kernels::opencl::ReshapeComputeFloatBuffer,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                       PRECISION(kFP16),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(reshape2,
+                     kOpenCL,
+                     kFP16,
+                     kNCHW,
+                     paddle::lite::kernels::opencl::ReshapeComputeFloatBuffer,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("ShapeTensor",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindInput("Shape",
+               {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("XShape",
+                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                       PRECISION(kFP16),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
+
+#define LITE_WITH_LOG
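[Editor's note] The kernel above never enqueues OpenCL work: a buffer-layout reshape is metadata-only. Run() snapshots output->dims() and the LoD before ShareDataWith(*x) because sharing copies the source tensor's whole descriptor (buffer handle and dims) into the output, so the already-inferred output shape has to be re-stamped with Resize() afterwards. A minimal standalone C++ sketch of that pattern — ToyTensor is a hypothetical stand-in, assuming ShareDataWith copies both buffer handle and dims, which is what makes the snapshot necessary:

  #include <cassert>
  #include <memory>
  #include <vector>

  // Hypothetical stand-in for lite::Tensor: sharing copies the whole
  // descriptor, dims included, so the reshaped dims must be saved first.
  struct ToyTensor {
    std::shared_ptr<std::vector<float>> buf;
    std::vector<int64_t> dims;
    void ShareDataWith(const ToyTensor& other) { *this = other; }
    void Resize(const std::vector<int64_t>& d) { dims = d; }
  };

  int main() {
    ToyTensor x{std::make_shared<std::vector<float>>(90), {15, 1, 2, 3}};
    ToyTensor out;
    out.Resize({1, 15, 6});      // shape already inferred by the reshape op
    auto saved_dims = out.dims;  // snapshot, exactly as Run() does
    out.ShareDataWith(x);        // zero-copy, but clobbers out.dims
    out.Resize(saved_dims);      // restore the reshaped dims
    assert(out.buf == x.buf && out.dims[2] == 6);
    return 0;
  }
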
diff --git a/lite/kernels/opencl/reshape_buffer_compute_test.cc b/lite/kernels/opencl/reshape_buffer_compute_test.cc
new file mode 100644
index 00000000000..a6b93ce77e3
--- /dev/null
+++ b/lite/kernels/opencl/reshape_buffer_compute_test.cc
@@ -0,0 +1,221 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <random>
+#include "lite/backends/opencl/target_wrapper.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+#include "lite/kernels/opencl/test_helper.h"
+#include "lite/operators/reshape_op.h"
+#include "lite/utils/log/logging.h"
+
+#define FP16_MAX_DIFF (5e-1)
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace opencl {
+static DDim ValidateShape(const std::vector<int>& shape,
+                          const DDim& input_dims) {
+  const lite::DDim::value_type input_size = input_dims.production();
+  auto input_shape = input_dims.Vectorize();
+  bool all_positive = std::all_of(
+      input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) {
+        return i > 0;
+      });
+  // Only one dimension can be set to -1; its size will be inferred
+  // automatically.
+  const int unk_dim_val = -1;
+  const int copy_dim_val = 0;
+
+  std::vector<lite::DDim::value_type> output_shape(shape.size(), 0);
+  lite::DDim::value_type capacity = 1;
+  int unk_dim_idx = -1;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (shape[i] == unk_dim_val) {
+      CHECK_EQ(unk_dim_idx, -1)
+          << "Only one input dimension of Attr(shape) can be unknown.";
+      unk_dim_idx = i;
+    } else if (shape[i] == copy_dim_val) {
+      CHECK_LT(static_cast<int>(i), input_shape.size())
+          << "The index of dimension to copy from input shape must be less "
+             "than the size of input shape.";
+    } else {
+      CHECK_GT(shape[i], 0) << "Each input dimension of Attr(shape) must not "
+                               "be negative except one unknown dimension.";
+    }
+
+    capacity *= (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
+                          : input_shape[i]);
+    output_shape[i] = (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
+                                : input_shape[i]);
+  }
+
+  if (unk_dim_idx != -1) {
+    if (all_positive) {
+      // input_size < 0 and is indeterminate at compile time; skip the check.
+      // For example: input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
+      // capacity = -24, input_size = -8, output_shape[0] = 0 --
+      // the following check would fail.
+      output_shape[unk_dim_idx] = -input_size / capacity;
+      CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size)
+          << "Invalid shape is given.";
+    } else {
+      output_shape[unk_dim_idx] = -1;
+    }
+  } else {
+    CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
+  }
+  return lite::DDim(output_shape);
+}
+
+TEST(reshape_opencl, compute) {
+  LOG(INFO) << "to get kernel ...";
+  auto kernels = KernelRegistry::Global().Create(
+      "reshape", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW));
+  ASSERT_FALSE(kernels.empty());
+  auto kernel = std::move(kernels.front());
+  lite_api::CLPrecisionType p = lite_api::CLPrecisionType::CL_PRECISION_FP16;
+  CLRuntime::Global()->set_precision(p);
+  const bool fp16_flag = (p == lite_api::CLPrecisionType::CL_PRECISION_FP16);
+  LOG(INFO) << "created reshape kernel";
+
+  LOG(INFO) << "prepare kernel ------";
+
+  int64_t batch_size = 15;
+  int64_t ic = 1;
+  int64_t ih = 2;
+  int64_t iw = 3;
+
+  lite::Tensor input, output, input_h;
+
+  operators::ReshapeParam param;
+
+  Tensor shape_tensor;
+  shape_tensor.Resize({3});
+  auto* shape_tensor_data = shape_tensor.mutable_data<int>();
+  shape_tensor_data[0] = 1;
+  shape_tensor_data[1] = 15;
+  shape_tensor_data[2] = 6;
+
+  if (fp16_flag) {
+    param.x = &input_h;
+    param.shape_tensor = &shape_tensor;  // use shape_tensor
+    param.inplace = true;
+    param.output = &output;
+  } else {
+    param.x = &input;
+    param.shape_tensor = &shape_tensor;  // use shape_tensor
+    param.inplace = true;
+    param.output = &output;
+  }
+
+  const DDim input_dim =
+      lite::DDim{std::vector<int64_t>({batch_size, ic, ih, iw})};
+  input.Resize(input_dim);
+  input_h.Resize(input_dim);
+
+  std::vector<int> final_shape = std::vector<int>(
+      shape_tensor_data, shape_tensor_data + shape_tensor.numel());
+  LOG(INFO) << "shape_tensor.numel() " << shape_tensor.numel();
+  auto out_dim = ValidateShape(final_shape, input_dim);
+  param.output->Resize(out_dim);
+  LOG(INFO) << " out_dim------" << out_dim;
+
+  LOG(INFO) << "prepare kernel SetParam------";
+  kernel->SetParam(param);
+  std::unique_ptr<KernelContext> context(new KernelContext);
+  context->As<OpenCLContext>().InitOnce();
+  kernel->SetContext(std::move(context));
+
+  auto* input_data_h =
+      input_h.mutable_data<half_t, cl::Buffer>(TARGET(kOpenCL));
+  auto* input_data = input.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
+
+  std::default_random_engine engine;
+  std::uniform_real_distribution<float> dist(-5, 5);
+  LOG(INFO) << "gen input ...";
+  std::vector<float> x_source(input_dim.production());
+  std::vector<half_t> x_source_half(input_dim.production());
+  for (size_t i = 0; i < input_dim.production(); ++i) {
+    x_source[i] = static_cast<float>(dist(engine));
+    x_source_half[i] = Float2Half(x_source[i]);
+  }
+
+  size_t x_size = input_dim.production() * sizeof(float);
+  if (fp16_flag) {
+    x_size = input_dim.production() * sizeof(half_t);
+    TargetWrapperCL::MemcpySync(
+        input_data_h, x_source_half.data(), x_size, IoDirection::HtoD);
+  } else {
+    TargetWrapperCL::MemcpySync(
+        input_data, x_source.data(), x_size, IoDirection::HtoD);
+  }
+
+  kernel->Launch();
+  CLRuntime::Global()->command_queue().finish();
+  auto* y_buffer = fp16_flag ? output.data<half_t, cl::Buffer>()
+                             : output.data<float, cl::Buffer>();
+  std::vector<float> out_data_from_gpu(out_dim.production());
+  std::vector<float> output_half2float(out_dim.production());
+  std::vector<half_t> out_data_from_gpu_half(out_dim.production());
+  if (fp16_flag) {
+    TargetWrapperCL::MemcpySync(out_data_from_gpu_half.data(),
+                                y_buffer,
+                                out_data_from_gpu_half.size() * sizeof(half_t),
+                                IoDirection::DtoH);
+  } else {
+    TargetWrapperCL::MemcpySync(out_data_from_gpu.data(),
+                                y_buffer,
+                                out_data_from_gpu.size() * sizeof(float),
+                                IoDirection::DtoH);
+  }
+  for (int eidx = 0; eidx < out_dim.production(); ++eidx) {
+    output_half2float[eidx] = Half2Float(out_data_from_gpu_half.data()[eidx]);
+  }
+
+  // check output dims
+  for (int i = 0; i < output.dims().size(); i++) {
+    CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
+  }
+
+  // check output data
+  for (int i = 0; i < output.numel(); i++) {
+    auto out_gpu_data = out_data_from_gpu[i];
+    if (fp16_flag) {
+      out_gpu_data = output_half2float[i];
+    }
+    auto abs_diff = abs(out_gpu_data - x_source[i]);
+    auto relative_diff = COMPUTE_RELATIVE_DIFF(out_gpu_data, x_source[i]);
+    EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF),
+              true);
+    if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) {
+      LOG(ERROR) << "error idx:" << i << " out_gpu_data[" << i
+                 << "]:" << out_gpu_data << " input_data[" << i
+                 << "]:" << x_source[i] << " abs_diff:" << abs_diff
+                 << " relative_diff:" << relative_diff
+                 << " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
+    }
+  }
+}
+
+}  // namespace opencl
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(reshape, kOpenCL, kFP16, kNCHW, def);
+USE_LITE_KERNEL(reshape2, kOpenCL, kFP16, kNCHW, def);
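[Editor's note] The test passes a fully specified shape {1, 15, 6}, but ValidateShape above also honors the reshape op's conventions: 0 copies the corresponding input dimension, and a single -1 is inferred from the remaining capacity. A condensed, self-contained re-derivation of the positive-dims path (not the helper itself; the unknown-input-dims branch is omitted):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // 0 copies the input dim; one -1 is inferred from leftover capacity.
  std::vector<int64_t> InferShape(const std::vector<int>& shape,
                                  const std::vector<int64_t>& in) {
    int64_t input_size = 1;
    for (int64_t d : in) input_size *= d;
    int64_t capacity = 1;
    int unk = -1;
    std::vector<int64_t> out(shape.size());
    for (size_t i = 0; i < shape.size(); ++i) {
      if (shape[i] == -1) unk = static_cast<int>(i);
      out[i] = shape[i] ? shape[i] : in[i];
      capacity *= out[i];  // the unknown dim contributes a factor of -1
    }
    if (unk >= 0) out[unk] = -input_size / capacity;  // capacity < 0 here
    return out;
  }

  int main() {
    // Input [15, 1, 2, 3] holds 90 elements, so [1, 15, -1] infers
    // 90 / 15 = 6 -- the same [1, 15, 6] the test requests explicitly.
    assert(InferShape({1, 15, -1}, {15, 1, 2, 3})[2] == 6);
    return 0;
  }

Note also the dual tolerance at the end of the test: an element passes if either the absolute or the relative difference is within FP16_MAX_DIFF, since relative error is meaningless near zero while absolute error alone is too strict for large magnitudes.
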
From 9de5a2193d28c066e140365134f4649e0a89476f Mon Sep 17 00:00:00 2001
From: sprouteer
Date: Thu, 7 Jul 2022 18:36:22 +0800
Subject: [PATCH 2/4] fix fc image2d bug test=develop

---
 lite/kernels/opencl/fc_image_compute.cc       |  9 +++++++--
 lite/kernels/opencl/reshape_buffer_compute.cc | 16 ++++++++--------
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/lite/kernels/opencl/fc_image_compute.cc b/lite/kernels/opencl/fc_image_compute.cc
index 0aa24da6643..25a2d669b8c 100644
--- a/lite/kernels/opencl/fc_image_compute.cc
+++ b/lite/kernels/opencl/fc_image_compute.cc
@@ -155,8 +155,13 @@ class FcImageCompute : public KernelLite<TARGET(kOpenCL),
     CHECK_EQ(param.input->dims().size(), 4UL);
-    m_ = x_dims.Slice(0, param.in_num_col_dims).production();
-    k_ = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production();
+    int in_num_col_dims = param.in_num_col_dims;
+    std::string op_type = param.op_type;
+    if (op_type == "matmul" || op_type == "matmul_v2") {
+      in_num_col_dims = x_dims.size() - 1;
+    }
+    m_ = x_dims.Slice(0, in_num_col_dims).production();
+    k_ = x_dims.Slice(in_num_col_dims, x_dims.size()).production();
     n_ = w_dims[1];
     CHECK_EQ(k_, static_cast<int>(w_dims[0]));
     k_blks_ = UP_DIV(k_, 4);
diff --git a/lite/kernels/opencl/reshape_buffer_compute.cc b/lite/kernels/opencl/reshape_buffer_compute.cc
index a45e606a92d..a51f8b30bae 100644
--- a/lite/kernels/opencl/reshape_buffer_compute.cc
+++ b/lite/kernels/opencl/reshape_buffer_compute.cc
@@ -33,7 +33,7 @@ namespace kernels {
 namespace opencl {
 
 class ReshapeComputeFloatBuffer
-    : public KernelLite<TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)> {
+    : public KernelLite<TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)> {
  public:
   using param_t = operators::ReshapeParam;
 
@@ -42,7 +42,7 @@ class ReshapeComputeFloatBuffer
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
     const Tensor* const x = param.x;
-    Tensor* const output = param.output;
+    Tensor* output = param.output;
 
     auto output_dims = output->dims();
     auto output_lod = output->lod();
@@ -71,13 +71,13 @@ class ReshapeComputeFloatBuffer
 
 REGISTER_LITE_KERNEL(reshape,
                      kOpenCL,
-                     kFP16,
+                     kFloat,
                      kNCHW,
                      paddle::lite::kernels::opencl::ReshapeComputeFloatBuffer,
                      def)
     .BindInput("X",
                {LiteType::GetTensorTy(TARGET(kOpenCL),
-                                      PRECISION(kFP16),
+                                      PRECISION(kAny),
                                       DATALAYOUT(kNCHW))})
     .BindInput("ShapeTensor",
                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
@@ -85,19 +85,19 @@ REGISTER_LITE_KERNEL(reshape,
                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kOpenCL),
-                                       PRECISION(kFP16),
+                                       PRECISION(kAny),
                                        DATALAYOUT(kNCHW))})
     .Finalize();
 
 REGISTER_LITE_KERNEL(reshape2,
                      kOpenCL,
-                     kFP16,
+                     kFloat,
                      kNCHW,
                      paddle::lite::kernels::opencl::ReshapeComputeFloatBuffer,
                      def)
     .BindInput("X",
                {LiteType::GetTensorTy(TARGET(kOpenCL),
-                                      PRECISION(kFP16),
+                                      PRECISION(kAny),
                                       DATALAYOUT(kNCHW))})
     .BindInput("ShapeTensor",
                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
@@ -107,7 +107,7 @@ REGISTER_LITE_KERNEL(reshape2,
                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kOpenCL),
-                                       PRECISION(kFP16),
+                                       PRECISION(kAny),
                                        DATALAYOUT(kNCHW))})
     .Finalize();
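[Editor's note] The fc part of this patch matters because the image2d fc kernel also backs fused matmul/matmul_v2: fc flattens its input into a matrix at axis in_num_col_dims, while matmul always contracts over the last axis only. A short sketch of the resulting m/k/n split, using hypothetical dims x_dims = [2, 3, 4] and w_dims = [4, 5] (illustrative values, not from the patch):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Product of dims[lo, hi) -- what DDim::Slice(lo, hi).production() yields.
  int64_t Prod(const std::vector<int64_t>& d, size_t lo, size_t hi) {
    int64_t p = 1;
    for (size_t i = lo; i < hi; ++i) p *= d[i];
    return p;
  }

  int main() {
    std::vector<int64_t> x_dims = {2, 3, 4}, w_dims = {4, 5};
    size_t in_num_col_dims = 1;  // fc attribute; wrong axis for matmul
    bool is_matmul = true;       // op_type == "matmul" || "matmul_v2"
    if (is_matmul) in_num_col_dims = x_dims.size() - 1;
    int64_t m = Prod(x_dims, 0, in_num_col_dims);              // 2 * 3 = 6
    int64_t k = Prod(x_dims, in_num_col_dims, x_dims.size());  // 4
    int64_t n = w_dims[1];                                     // 5
    // Without the fix, k would be 3 * 4 = 12 and the kernel's
    // CHECK_EQ(k_, w_dims[0]) would fire.
    assert(k == w_dims[0] && m == 6 && n == 5);
    return 0;
  }

The reshape half of the patch re-keys the buffer kernel from kFP16 to kFloat and loosens its tensor bindings to kAny, so the type system matches the kernel regardless of the actual buffer precision.
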
From 30e52580cf9e44e840e5122d888b7c18b8eb48d2 Mon Sep 17 00:00:00 2001
From: sprouteer
Date: Thu, 7 Jul 2022 19:30:32 +0800
Subject: [PATCH 3/4] fix fc image2d bug test=develop

---
 lite/kernels/opencl/reshape_buffer_compute_test.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/lite/kernels/opencl/reshape_buffer_compute_test.cc b/lite/kernels/opencl/reshape_buffer_compute_test.cc
index a6b93ce77e3..9afed17f450 100644
--- a/lite/kernels/opencl/reshape_buffer_compute_test.cc
+++ b/lite/kernels/opencl/reshape_buffer_compute_test.cc
@@ -84,7 +84,7 @@ static DDim ValidateShape(const std::vector<int>& shape,
 TEST(reshape_opencl, compute) {
   LOG(INFO) << "to get kernel ...";
   auto kernels = KernelRegistry::Global().Create(
-      "reshape", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW));
+      "reshape", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW));
   ASSERT_FALSE(kernels.empty());
   auto kernel = std::move(kernels.front());
   lite_api::CLPrecisionType p = lite_api::CLPrecisionType::CL_PRECISION_FP16;
@@ -217,5 +217,5 @@ TEST(reshape_opencl, compute) {
 }  // namespace lite
 }  // namespace paddle
 
-USE_LITE_KERNEL(reshape, kOpenCL, kFP16, kNCHW, def);
-USE_LITE_KERNEL(reshape2, kOpenCL, kFP16, kNCHW, def);
+USE_LITE_KERNEL(reshape, kOpenCL, kFloat, kNCHW, def);
+USE_LITE_KERNEL(reshape2, kOpenCL, kFloat, kNCHW, def);

From 96d4c26924d16714c3c1a4ec05b7ea6a30a49c6f Mon Sep 17 00:00:00 2001
From: sprouteer
Date: Fri, 22 Jul 2022 11:32:52 +0800
Subject: [PATCH 4/4] fix opencl reshape op_test test=develop

---
 lite/kernels/opencl/CMakeLists.txt                 |   3 -
 .../opencl/reshape_buffer_compute_test.cc          | 221 ------------------
 lite/tests/kernels/reshape_compute_test.cc         |   3 +
 3 files changed, 3 insertions(+), 224 deletions(-)
 delete mode 100644 lite/kernels/opencl/reshape_buffer_compute_test.cc

diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt
index 6e799d9f0d9..3c0b7ce1359 100644
--- a/lite/kernels/opencl/CMakeLists.txt
+++ b/lite/kernels/opencl/CMakeLists.txt
@@ -216,6 +216,3 @@ lite_cc_test(test_fc_buffer_opencl SRCS fc_buffer_compute_test.cc
 lite_cc_test(test_io_copy_buffer_opencl SRCS io_copy_buffer_compute_test.cc
              DEPS kernels core)
-
-lite_cc_test(test_reshape_buffer_opencl SRCS reshape_buffer_compute_test.cc
-             DEPS kernels core)
diff --git a/lite/kernels/opencl/reshape_buffer_compute_test.cc b/lite/kernels/opencl/reshape_buffer_compute_test.cc
deleted file mode 100644
index 9afed17f450..00000000000
--- a/lite/kernels/opencl/reshape_buffer_compute_test.cc
+++ /dev/null
@@ -1,221 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <gtest/gtest.h>
-#include <random>
-#include "lite/backends/opencl/target_wrapper.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/tensor.h"
-#include "lite/kernels/opencl/test_helper.h"
-#include "lite/operators/reshape_op.h"
-#include "lite/utils/log/logging.h"
-
-#define FP16_MAX_DIFF (5e-1)
-
-namespace paddle {
-namespace lite {
-namespace kernels {
-namespace opencl {
-static DDim ValidateShape(const std::vector<int>& shape,
-                          const DDim& input_dims) {
-  const lite::DDim::value_type input_size = input_dims.production();
-  auto input_shape = input_dims.Vectorize();
-  bool all_positive = std::all_of(
-      input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) {
-        return i > 0;
-      });
-  // Only one dimension can be set to -1; its size will be inferred
-  // automatically.
-  const int unk_dim_val = -1;
-  const int copy_dim_val = 0;
-
-  std::vector<lite::DDim::value_type> output_shape(shape.size(), 0);
-  lite::DDim::value_type capacity = 1;
-  int unk_dim_idx = -1;
-  for (size_t i = 0; i < shape.size(); ++i) {
-    if (shape[i] == unk_dim_val) {
-      CHECK_EQ(unk_dim_idx, -1)
-          << "Only one input dimension of Attr(shape) can be unknown.";
-      unk_dim_idx = i;
-    } else if (shape[i] == copy_dim_val) {
-      CHECK_LT(static_cast<int>(i), input_shape.size())
-          << "The index of dimension to copy from input shape must be less "
-             "than the size of input shape.";
-    } else {
-      CHECK_GT(shape[i], 0) << "Each input dimension of Attr(shape) must not "
-                               "be negative except one unknown dimension.";
-    }
-
-    capacity *= (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
-                          : input_shape[i]);
-    output_shape[i] = (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
-                                : input_shape[i]);
-  }
-
-  if (unk_dim_idx != -1) {
-    if (all_positive) {
-      // input_size < 0 and is indeterminate at compile time; skip the check.
-      // For example: input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
-      // capacity = -24, input_size = -8, output_shape[0] = 0 --
-      // the following check would fail.
-      output_shape[unk_dim_idx] = -input_size / capacity;
-      CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size)
-          << "Invalid shape is given.";
-    } else {
-      output_shape[unk_dim_idx] = -1;
-    }
-  } else {
-    CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
-  }
-  return lite::DDim(output_shape);
-}
-
-TEST(reshape_opencl, compute) {
-  LOG(INFO) << "to get kernel ...";
-  auto kernels = KernelRegistry::Global().Create(
-      "reshape", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW));
-  ASSERT_FALSE(kernels.empty());
-  auto kernel = std::move(kernels.front());
-  lite_api::CLPrecisionType p = lite_api::CLPrecisionType::CL_PRECISION_FP16;
-  CLRuntime::Global()->set_precision(p);
-  const bool fp16_flag = (p == lite_api::CLPrecisionType::CL_PRECISION_FP16);
-  LOG(INFO) << "created reshape kernel";
-
-  LOG(INFO) << "prepare kernel ------";
-
-  int64_t batch_size = 15;
-  int64_t ic = 1;
-  int64_t ih = 2;
-  int64_t iw = 3;
-
-  lite::Tensor input, output, input_h;
-
-  operators::ReshapeParam param;
-
-  Tensor shape_tensor;
-  shape_tensor.Resize({3});
-  auto* shape_tensor_data = shape_tensor.mutable_data<int>();
-  shape_tensor_data[0] = 1;
-  shape_tensor_data[1] = 15;
-  shape_tensor_data[2] = 6;
-
-  if (fp16_flag) {
-    param.x = &input_h;
-    param.shape_tensor = &shape_tensor;  // use shape_tensor
-    param.inplace = true;
-    param.output = &output;
-  } else {
-    param.x = &input;
-    param.shape_tensor = &shape_tensor;  // use shape_tensor
-    param.inplace = true;
-    param.output = &output;
-  }
-
-  const DDim input_dim =
-      lite::DDim{std::vector<int64_t>({batch_size, ic, ih, iw})};
-  input.Resize(input_dim);
-  input_h.Resize(input_dim);
-
-  std::vector<int> final_shape = std::vector<int>(
-      shape_tensor_data, shape_tensor_data + shape_tensor.numel());
-  LOG(INFO) << "shape_tensor.numel() " << shape_tensor.numel();
-  auto out_dim = ValidateShape(final_shape, input_dim);
-  param.output->Resize(out_dim);
-  LOG(INFO) << " out_dim------" << out_dim;
-
-  LOG(INFO) << "prepare kernel SetParam------";
-  kernel->SetParam(param);
-  std::unique_ptr<KernelContext> context(new KernelContext);
-  context->As<OpenCLContext>().InitOnce();
-  kernel->SetContext(std::move(context));
-
-  auto* input_data_h =
-      input_h.mutable_data<half_t, cl::Buffer>(TARGET(kOpenCL));
-  auto* input_data = input.mutable_data<float, cl::Buffer>(TARGET(kOpenCL));
-
-  std::default_random_engine engine;
-  std::uniform_real_distribution<float> dist(-5, 5);
-  LOG(INFO) << "gen input ...";
-  std::vector<float> x_source(input_dim.production());
-  std::vector<half_t> x_source_half(input_dim.production());
-  for (size_t i = 0; i < input_dim.production(); ++i) {
-    x_source[i] = static_cast<float>(dist(engine));
-    x_source_half[i] = Float2Half(x_source[i]);
-  }
-
-  size_t x_size = input_dim.production() * sizeof(float);
-  if (fp16_flag) {
-    x_size = input_dim.production() * sizeof(half_t);
-    TargetWrapperCL::MemcpySync(
-        input_data_h, x_source_half.data(), x_size, IoDirection::HtoD);
-  } else {
-    TargetWrapperCL::MemcpySync(
-        input_data, x_source.data(), x_size, IoDirection::HtoD);
-  }
-
-  kernel->Launch();
-  CLRuntime::Global()->command_queue().finish();
-  auto* y_buffer = fp16_flag ? output.data<half_t, cl::Buffer>()
-                             : output.data<float, cl::Buffer>();
-  std::vector<float> out_data_from_gpu(out_dim.production());
-  std::vector<float> output_half2float(out_dim.production());
-  std::vector<half_t> out_data_from_gpu_half(out_dim.production());
-  if (fp16_flag) {
-    TargetWrapperCL::MemcpySync(out_data_from_gpu_half.data(),
-                                y_buffer,
-                                out_data_from_gpu_half.size() * sizeof(half_t),
-                                IoDirection::DtoH);
-  } else {
-    TargetWrapperCL::MemcpySync(out_data_from_gpu.data(),
-                                y_buffer,
-                                out_data_from_gpu.size() * sizeof(float),
-                                IoDirection::DtoH);
-  }
-  for (int eidx = 0; eidx < out_dim.production(); ++eidx) {
-    output_half2float[eidx] = Half2Float(out_data_from_gpu_half.data()[eidx]);
-  }
-
-  // check output dims
-  for (int i = 0; i < output.dims().size(); i++) {
-    CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
-  }
-
-  // check output data
-  for (int i = 0; i < output.numel(); i++) {
-    auto out_gpu_data = out_data_from_gpu[i];
-    if (fp16_flag) {
-      out_gpu_data = output_half2float[i];
-    }
-    auto abs_diff = abs(out_gpu_data - x_source[i]);
-    auto relative_diff = COMPUTE_RELATIVE_DIFF(out_gpu_data, x_source[i]);
-    EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF),
-              true);
-    if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) {
-      LOG(ERROR) << "error idx:" << i << " out_gpu_data[" << i
-                 << "]:" << out_gpu_data << " input_data[" << i
-                 << "]:" << x_source[i] << " abs_diff:" << abs_diff
-                 << " relative_diff:" << relative_diff
-                 << " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
-    }
-  }
-}
-
-}  // namespace opencl
-}  // namespace kernels
-}  // namespace lite
-}  // namespace paddle
-
-USE_LITE_KERNEL(reshape, kOpenCL, kFloat, kNCHW, def);
-USE_LITE_KERNEL(reshape2, kOpenCL, kFloat, kNCHW, def);
diff --git a/lite/tests/kernels/reshape_compute_test.cc b/lite/tests/kernels/reshape_compute_test.cc
index b9099d1c8cd..f2059b8ad9b 100644
--- a/lite/tests/kernels/reshape_compute_test.cc
+++ b/lite/tests/kernels/reshape_compute_test.cc
@@ -217,6 +217,9 @@ TEST(Reshape, precision) {
 #elif defined(LITE_WITH_NPU)
   place = TARGET(kNPU);
   abs_error = 1e-2;  // Using fp16 in NPU
+#elif defined(LITE_WITH_OPENCL)
+  place = Place(TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW));
+  abs_error = 1e-2;  // Using fp16 in OPENCL
 #elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
   place = TARGET(kHost);
 #else
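
[Editor's note] After patch 4, coverage for the buffer reshape kernel comes from the shared lite/tests/kernels/reshape_compute_test.cc harness via the OpenCL Place added above. Anything that still wants the kernel directly must look it up under the final registration key from patches 2-3 — a lookup sketch condensed from the deleted test:

  auto kernels = KernelRegistry::Global().Create(
      "reshape", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW));
  // A PRECISION(kFP16) query would now come back empty.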