Skip to content

Commit

Permalink
[Metal] fix bilinear_interp bug (#8285)
Browse files Browse the repository at this point in the history
* fix_bilinear_interp

* Update test_bilinear_interp_op.py
  • Loading branch information
xiaoxiaohehe001 authored Jan 18, 2022
1 parent 5f78558 commit 37c3961
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 49 deletions.
13 changes: 11 additions & 2 deletions lite/backends/metal/metal_kernel/texture/BilinearInterp.metal
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,17 @@ kernel void bilinear_interp(texture2d_array<ftype, access::read> input[[texture(
} else {
ftype w = (gid.x + pm.align_delta) * pm.ratio_w - pm.align_delta;
ftype h = (gid.y + pm.align_delta) * pm.ratio_h - pm.align_delta;
uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1;
h = (h > 0) ? h : 0;
w = (w > 0) ? w : 0;
int w0 = (int)w;
int h0 = (int)h;
int w1 = w0 + 1, h1 = h0 + 1;
if (w0 < 0) {
w0 = 0;
}
if (h0 < 0) {
h0 = 0;
}

ftype w1lambda = w - w0, h1lambda = h - h0;
ftype w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
Expand Down
24 changes: 12 additions & 12 deletions lite/kernels/metal/image_op/interp_image_compute.mm
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@
PRECISION(kFloat),
DATALAYOUT(kMetalTexture2DArray))})
.BindInput("OutSize",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindInput("Scale",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMetal),
PRECISION(kFloat),
Expand All @@ -205,11 +205,11 @@
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))})
.BindInput("OutSize",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindInput("Scale",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))})
.Finalize();
Expand All @@ -225,11 +225,11 @@
PRECISION(kFloat),
DATALAYOUT(kMetalTexture2DArray))})
.BindInput("OutSize",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindInput("Scale",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMetal),
PRECISION(kFloat),
Expand All @@ -245,11 +245,11 @@
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))})
.BindInput("OutSize",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindInput("Scale",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW))})
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kInt32), DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))})
.Finalize();
Expand Down
45 changes: 28 additions & 17 deletions lite/tests/unittest_py/op/test_bilinear_interp_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,29 @@ def __init__(self, *args, **kwargs):
PrecisionType.FP32,
DataLayoutType.NCHW,
thread=[1, 4])
# opencl demo
# opencl has diff
# opencl_places = [
# Place(TargetType.OpenCL, PrecisionType.FP16,
# DataLayoutType.ImageDefault), Place(
# TargetType.OpenCL, PrecisionType.FP16,
# DataLayoutType.ImageFolder),
# Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW),
# Place(TargetType.OpenCL, PrecisionType.Any,
# DataLayoutType.ImageDefault), Place(
# TargetType.OpenCL, PrecisionType.Any,
# DataLayoutType.ImageFolder),
# Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW),
# Place(TargetType.Host, PrecisionType.FP32)
# ]
# self.enable_testing_on_place(places=opencl_places)
opencl_places = [
Place(TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW),
Place(TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW),
Place(TargetType.Host, PrecisionType.FP32)
]
self.enable_testing_on_place(places=opencl_places)
metal_places = [
Place(TargetType.Metal, PrecisionType.FP32,
DataLayoutType.MetalTexture2DArray),
Place(TargetType.Metal, PrecisionType.FP16,
DataLayoutType.MetalTexture2DArray),
Place(TargetType.ARM, PrecisionType.FP32),
Place(TargetType.Host, PrecisionType.FP32)
]
self.enable_testing_on_place(places=metal_places)

def is_program_valid(self,
program_config: ProgramConfig,
Expand Down Expand Up @@ -123,7 +130,11 @@ def generate_scale(*args, **kwargs):
return program_config

def sample_predictor_configs(self):
return self.get_predictor_configs(), ["bilinear_interp"], (1e-4, 1e-4)
atol, rtol = 1e-4, 1e-4
target_str = self.get_target()
if target_str == "Metal":
atol, rtol = 5e-1, 5e-1
return self.get_predictor_configs(), ["bilinear_interp"], (atol, rtol)

def add_ignore_pass_case(self):
pass
Expand Down
47 changes: 29 additions & 18 deletions lite/tests/unittest_py/op/test_bilinear_interp_v2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,29 @@ def __init__(self, *args, **kwargs):
PrecisionType.FP32,
DataLayoutType.NCHW,
thread=[1, 4])
# opencl demo
# opencl has diff
# opencl_places = [
# Place(TargetType.OpenCL, PrecisionType.FP16,
# DataLayoutType.ImageDefault), Place(
# TargetType.OpenCL, PrecisionType.FP16,
# DataLayoutType.ImageFolder),
# Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW),
# Place(TargetType.OpenCL, PrecisionType.Any,
# DataLayoutType.ImageDefault), Place(
# TargetType.OpenCL, PrecisionType.Any,
# DataLayoutType.ImageFolder),
# Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW),
# Place(TargetType.Host, PrecisionType.FP32)
# ]
# self.enable_testing_on_place(places=opencl_places)
opencl_places = [
Place(TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW),
Place(TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW),
Place(TargetType.Host, PrecisionType.FP32)
]
self.enable_testing_on_place(places=opencl_places)
metal_places = [
Place(TargetType.Metal, PrecisionType.FP32,
DataLayoutType.MetalTexture2DArray),
Place(TargetType.Metal, PrecisionType.FP16,
DataLayoutType.MetalTexture2DArray),
Place(TargetType.ARM, PrecisionType.FP32),
Place(TargetType.Host, PrecisionType.FP32)
]
self.enable_testing_on_place(places=metal_places)

def is_program_valid(self,
program_config: ProgramConfig,
Expand Down Expand Up @@ -128,8 +135,12 @@ def generate_scale(*args, **kwargs):
return program_config

def sample_predictor_configs(self):
return self.get_predictor_configs(), ["bilinear_interp_v2"], (1e-4,
1e-4)
atol, rtol = 1e-4, 1e-4
target_str = self.get_target()
if target_str == "Metal":
atol, rtol = 3e-1, 3e-1
return self.get_predictor_configs(), ["bilinear_interp_v2"], (atol,
rtol)

def add_ignore_pass_case(self):
pass
Expand Down

0 comments on commit 37c3961

Please sign in to comment.