From 37c396121e4de6711925f0eded0d8648957aa43d Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Tue, 18 Jan 2022 16:35:58 +0800 Subject: [PATCH] [Metal] fix bilinear_interp bug (#8285) * fix_bilinear_interp * Update test_bilinear_interp_op.py --- .../metal_kernel/texture/BilinearInterp.metal | 13 ++++- .../metal/image_op/interp_image_compute.mm | 24 +++++----- .../unittest_py/op/test_bilinear_interp_op.py | 45 +++++++++++------- .../op/test_bilinear_interp_v2_op.py | 47 ++++++++++++------- 4 files changed, 80 insertions(+), 49 deletions(-) diff --git a/lite/backends/metal/metal_kernel/texture/BilinearInterp.metal b/lite/backends/metal/metal_kernel/texture/BilinearInterp.metal index e8e35270130..4477f8add0a 100644 --- a/lite/backends/metal/metal_kernel/texture/BilinearInterp.metal +++ b/lite/backends/metal/metal_kernel/texture/BilinearInterp.metal @@ -32,8 +32,17 @@ kernel void bilinear_interp(texture2d_array input[[texture( } else { ftype w = (gid.x + pm.align_delta) * pm.ratio_w - pm.align_delta; ftype h = (gid.y + pm.align_delta) * pm.ratio_h - pm.align_delta; - uint w0 = w, h0 = h; - uint w1 = w0 + 1, h1 = h0 + 1; + h = (h > 0) ? h : 0; + w = (w > 0) ? w : 0; + int w0 = (int)w; + int h0 = (int)h; + int w1 = w0 + 1, h1 = h0 + 1; + if (w0 < 0) { + w0 = 0; + } + if (h0 < 0) { + h0 = 0; + } ftype w1lambda = w - w0, h1lambda = h - h0; ftype w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda; diff --git a/lite/kernels/metal/image_op/interp_image_compute.mm b/lite/kernels/metal/image_op/interp_image_compute.mm index ca88abc1de4..3b45996ee63 100644 --- a/lite/kernels/metal/image_op/interp_image_compute.mm +++ b/lite/kernels/metal/image_op/interp_image_compute.mm @@ -185,11 +185,11 @@ PRECISION(kFloat), DATALAYOUT(kMetalTexture2DArray))}) .BindInput("OutSize", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindInput("SizeTensor", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindInput("Scale", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFloat), @@ -205,11 +205,11 @@ .BindInput("X", {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))}) .BindInput("OutSize", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindInput("SizeTensor", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindInput("Scale", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))}) .Finalize(); @@ -225,11 +225,11 @@ PRECISION(kFloat), DATALAYOUT(kMetalTexture2DArray))}) .BindInput("OutSize", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindInput("SizeTensor", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindInput("Scale", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFloat), @@ -245,11 +245,11 @@ .BindInput("X", {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))}) .BindInput("OutSize", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindInput("SizeTensor", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindInput("Scale", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))}) + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))}) .Finalize(); diff --git a/lite/tests/unittest_py/op/test_bilinear_interp_op.py b/lite/tests/unittest_py/op/test_bilinear_interp_op.py index 714bf95b15b..59f7cd5683c 100644 --- a/lite/tests/unittest_py/op/test_bilinear_interp_op.py +++ b/lite/tests/unittest_py/op/test_bilinear_interp_op.py @@ -40,22 +40,29 @@ def __init__(self, *args, **kwargs): PrecisionType.FP32, DataLayoutType.NCHW, thread=[1, 4]) - # opencl demo - # opencl has diff - # opencl_places = [ - # Place(TargetType.OpenCL, PrecisionType.FP16, - # DataLayoutType.ImageDefault), Place( - # TargetType.OpenCL, PrecisionType.FP16, - # DataLayoutType.ImageFolder), - # Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW), - # Place(TargetType.OpenCL, PrecisionType.Any, - # DataLayoutType.ImageDefault), Place( - # TargetType.OpenCL, PrecisionType.Any, - # DataLayoutType.ImageFolder), - # Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW), - # Place(TargetType.Host, PrecisionType.FP32) - # ] - # self.enable_testing_on_place(places=opencl_places) + opencl_places = [ + Place(TargetType.OpenCL, PrecisionType.FP16, + DataLayoutType.ImageDefault), Place( + TargetType.OpenCL, PrecisionType.FP16, + DataLayoutType.ImageFolder), + Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW), + Place(TargetType.OpenCL, PrecisionType.Any, + DataLayoutType.ImageDefault), Place( + TargetType.OpenCL, PrecisionType.Any, + DataLayoutType.ImageFolder), + Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW), + Place(TargetType.Host, PrecisionType.FP32) + ] + self.enable_testing_on_place(places=opencl_places) + metal_places = [ + Place(TargetType.Metal, PrecisionType.FP32, + DataLayoutType.MetalTexture2DArray), + Place(TargetType.Metal, PrecisionType.FP16, + DataLayoutType.MetalTexture2DArray), + Place(TargetType.ARM, PrecisionType.FP32), + Place(TargetType.Host, PrecisionType.FP32) + ] + self.enable_testing_on_place(places=metal_places) def is_program_valid(self, program_config: ProgramConfig, @@ -123,7 +130,11 @@ def generate_scale(*args, **kwargs): return program_config def sample_predictor_configs(self): - return self.get_predictor_configs(), ["bilinear_interp"], (1e-4, 1e-4) + atol, rtol = 1e-4, 1e-4 + target_str = self.get_target() + if target_str == "Metal": + atol, rtol = 5e-1, 5e-1 + return self.get_predictor_configs(), ["bilinear_interp"], (atol, rtol) def add_ignore_pass_case(self): pass diff --git a/lite/tests/unittest_py/op/test_bilinear_interp_v2_op.py b/lite/tests/unittest_py/op/test_bilinear_interp_v2_op.py index b28ba4b85c5..0ca343d3ff9 100644 --- a/lite/tests/unittest_py/op/test_bilinear_interp_v2_op.py +++ b/lite/tests/unittest_py/op/test_bilinear_interp_v2_op.py @@ -40,22 +40,29 @@ def __init__(self, *args, **kwargs): PrecisionType.FP32, DataLayoutType.NCHW, thread=[1, 4]) - # opencl demo - # opencl has diff - # opencl_places = [ - # Place(TargetType.OpenCL, PrecisionType.FP16, - # DataLayoutType.ImageDefault), Place( - # TargetType.OpenCL, PrecisionType.FP16, - # DataLayoutType.ImageFolder), - # Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW), - # Place(TargetType.OpenCL, PrecisionType.Any, - # DataLayoutType.ImageDefault), Place( - # TargetType.OpenCL, PrecisionType.Any, - # DataLayoutType.ImageFolder), - # Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW), - # Place(TargetType.Host, PrecisionType.FP32) - # ] - # self.enable_testing_on_place(places=opencl_places) + opencl_places = [ + Place(TargetType.OpenCL, PrecisionType.FP16, + DataLayoutType.ImageDefault), Place( + TargetType.OpenCL, PrecisionType.FP16, + DataLayoutType.ImageFolder), + Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW), + Place(TargetType.OpenCL, PrecisionType.Any, + DataLayoutType.ImageDefault), Place( + TargetType.OpenCL, PrecisionType.Any, + DataLayoutType.ImageFolder), + Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW), + Place(TargetType.Host, PrecisionType.FP32) + ] + self.enable_testing_on_place(places=opencl_places) + metal_places = [ + Place(TargetType.Metal, PrecisionType.FP32, + DataLayoutType.MetalTexture2DArray), + Place(TargetType.Metal, PrecisionType.FP16, + DataLayoutType.MetalTexture2DArray), + Place(TargetType.ARM, PrecisionType.FP32), + Place(TargetType.Host, PrecisionType.FP32) + ] + self.enable_testing_on_place(places=metal_places) def is_program_valid(self, program_config: ProgramConfig, @@ -128,8 +135,12 @@ def generate_scale(*args, **kwargs): return program_config def sample_predictor_configs(self): - return self.get_predictor_configs(), ["bilinear_interp_v2"], (1e-4, - 1e-4) + atol, rtol = 1e-4, 1e-4 + target_str = self.get_target() + if target_str == "Metal": + atol, rtol = 3e-1, 3e-1 + return self.get_predictor_configs(), ["bilinear_interp_v2"], (atol, + rtol) def add_ignore_pass_case(self): pass