From c28ec2eb9c7a10741820c71e33605ace50f50df4 Mon Sep 17 00:00:00 2001 From: Zhang Jun Date: Mon, 29 Nov 2021 14:51:27 +0800 Subject: [PATCH] fix_concat (#7465) (#7747) Co-authored-by: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> --- .../texture/ConcatKernel.inc.metal | 298 ---------- .../metal_kernel/texture/ConcatKernel.metal | 509 ++++++------------ .../metal/image_op/concat_image_compute.mm | 28 +- lite/kernels/metal/image_op/metal_params.h | 2 + 4 files changed, 191 insertions(+), 646 deletions(-) delete mode 100644 lite/backends/metal/metal_kernel/texture/ConcatKernel.inc.metal diff --git a/lite/backends/metal/metal_kernel/texture/ConcatKernel.inc.metal b/lite/backends/metal/metal_kernel/texture/ConcatKernel.inc.metal deleted file mode 100644 index db7a0a0c360..00000000000 --- a/lite/backends/metal/metal_kernel/texture/ConcatKernel.inc.metal +++ /dev/null @@ -1,298 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#ifdef P - -#define CONCAT2(a, b) a##b -#define CONCAT4_(a, b, c, d) a##_##b##_##c##_##d - -#define FUNC(f, r, n, v) CONCAT4_(f, r, n, v) -#define VECTOR(p, n) CONCAT2(p, n) - -#if V == VNORMAL -kernel void FUNC(concat, R, N, normal)(texture2d_array in0[[texture(0)]], - texture2d_array in1[[texture(1)]], -#if N >= 3 - texture2d_array in2[[texture(2)]], -#endif -#if N >= 4 - texture2d_array in3[[texture(3)]], -#endif -#if N >= 5 - texture2d_array in4[[texture(4)]], -#endif -#if N >= 6 - texture2d_array in5[[texture(5)]], -#endif - texture2d_array inx[[texture(N)]], - texture2d_array out[[texture(N + 1)]], - constant ConcatParam& pm[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - - ConcatParam cp = pm; - int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4]; - VECTOR(P, 4) r = inx.read(gid.xy, gid.z); - for (int i = 0; i < 4; i++) { - xyzn[3] = i; - xyzn2abcd_4(cp.odim[3], xyzn, abcd); - int k = abcd[cp.axis] - cp.offset; - if (k < 0) continue; - int j = 0; - for (; j < N; j++) { - if (k < cp.vdim[j]) { - break; - } - k -= cp.vdim[j]; - } - if (j == N) { - continue; - } - int ta = cp.odim[cp.axis]; - abcd[cp.axis] = k; - cp.odim[cp.axis] = cp.vdim[j]; - abcd2xyzn_4(cp.odim[3], abcd, oxyzn); - cp.odim[cp.axis] = ta; - switch (j) { - case 0: - r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; - break; - case 1: - r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; - break; -#if N >= 3 - case 2: - r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; - break; -#endif -#if N >= 4 - case 3: - r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; - break; -#endif -#if N >= 5 - case 4: - r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; - break; -#endif -#if N >= 6 - case 5: - r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; - break; -#endif - } - } - out.write(r, gid.xy, gid.z); -} - -#endif // V == NORMAL - -#if V == VX -kernel void FUNC(concat, R, N, x)(texture2d_array in0[[texture(0)]], - texture2d_array in1[[texture(1)]], -#if N >= 3 - texture2d_array in2[[texture(2)]], -#endif // N >= 3 -#if N >= 4 - texture2d_array in3[[texture(3)]], -#endif // N >= 4 -#if N >= 5 - texture2d_array in4[[texture(4)]], -#endif // N >= 5 -#if N >= 6 - texture2d_array in5[[texture(5)]], -#endif // N >= 6 - texture2d_array out[[texture(N)]], - constant ConcatParam& pm[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - int x = gid.x - pm.offset; - if (x < 0) return; - if (x < pm.vdim[0]) { - VECTOR(P, 4) r = in0.read(gid.xy, gid.z); - out.write(r, gid.xy, gid.z); - return; - } - x -= pm.vdim[0]; - if (x < pm.vdim[1]) { - VECTOR(P, 4) r = in1.read(uint2(x, gid.y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#if N >= 3 - x -= pm.vdim[1]; - if (x < pm.vdim[2]) { - VECTOR(P, 4) r = in2.read(uint2(x, gid.y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 3 -#if N >= 4 - x -= pm.vdim[2]; - if (x < pm.vdim[3]) { - VECTOR(P, 4) r = in3.read(uint2(x, gid.y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 4 -#if N >= 5 - x -= pm.vdim[3]; - if (x < pm.vdim[4]) { - VECTOR(P, 4) r = in4.read(uint2(x, gid.y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 5 -#if N >= 6 - x -= pm.vdim[4]; - if (x < pm.vdim[5]) { - VECTOR(P, 4) r = in5.read(uint2(x, gid.y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 6 -} -#endif // V == VX - -#if V == VY -kernel void FUNC(concat, R, N, y)(texture2d_array in0[[texture(0)]], - texture2d_array in1[[texture(1)]], -#if N >= 3 - texture2d_array in2[[texture(2)]], -#endif // N >= 3 -#if N >= 4 - texture2d_array in3[[texture(3)]], -#endif // N >= 4 -#if N >= 5 - texture2d_array in4[[texture(4)]], -#endif // N >= 5 -#if N >= 6 - texture2d_array in5[[texture(5)]], -#endif // N >= 6 - texture2d_array out[[texture(N)]], - constant ConcatParam& pm[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - int y = gid.y - pm.offset; - if (y < 0) return; - if (y < pm.vdim[0]) { - VECTOR(P, 4) r = in0.read(gid.xy, gid.z); - out.write(r, gid.xy, gid.z); - return; - } - y -= pm.vdim[0]; - if (y < pm.vdim[1]) { - VECTOR(P, 4) r = in1.read(uint2(gid.x, y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#if N >= 3 - y -= pm.vdim[1]; - if (y < pm.vdim[2]) { - VECTOR(P, 4) r = in2.read(uint2(gid.x, y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 3 -#if N >= 4 - y -= pm.vdim[2]; - if (y < pm.vdim[3]) { - VECTOR(P, 4) r = in3.read(uint2(gid.x, y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 4 -#if N >= 5 - y -= pm.vdim[3]; - if (y < pm.vdim[4]) { - VECTOR(P, 4) r = in4.read(uint2(gid.x, y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 5 -#if N >= 6 - y -= pm.vdim[4]; - if (y < pm.vdim[5]) { - VECTOR(P, 4) r = in5.read(uint2(gid.x, y), gid.z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 6 -} -#endif // V == VY - -#if V == VZ -kernel void FUNC(concat, R, N, z)(texture2d_array in0[[texture(0)]], - texture2d_array in1[[texture(1)]], -#if N >= 3 - texture2d_array in2[[texture(2)]], -#endif // N >= 3 -#if N >= 4 - texture2d_array in3[[texture(3)]], -#endif // N >= 4 -#if N >= 5 - texture2d_array in4[[texture(4)]], -#endif // N >= 5 -#if N >= 6 - texture2d_array in5[[texture(5)]], -#endif // N >= 6 - texture2d_array out[[texture(N)]], - constant ConcatParam& pm[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - int z = gid.z - pm.offset; - if (z < 0) return; - if (z < pm.vdim[0]) { - VECTOR(P, 4) r = in0.read(gid.xy, gid.z); - out.write(r, gid.xy, gid.z); - return; - } - z -= pm.vdim[0]; - if (z < pm.vdim[1]) { - VECTOR(P, 4) r = in1.read(gid.xy, z); - out.write(r, gid.xy, gid.z); - return; - } -#if N >= 3 - z -= pm.vdim[1]; - if (z < pm.vdim[2]) { - VECTOR(P, 4) r = in2.read(gid.xy, z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 3 -#if N >= 4 - z -= pm.vdim[2]; - if (z < pm.vdim[3]) { - VECTOR(P, 4) r = in3.read(gid.xy, z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 4 -#if N >= 5 - z -= pm.vdim[3]; - if (z < pm.vdim[4]) { - VECTOR(P, 4) r = in4.read(gid.xy, z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 5 -#if N >= 6 - z -= pm.vdim[4]; - if (z < pm.vdim[5]) { - VECTOR(P, 4) r = in5.read(gid.xy, z); - out.write(r, gid.xy, gid.z); - return; - } -#endif // N >= 6 -} -#endif // V == VZ - -#endif // #ifdef P diff --git a/lite/backends/metal/metal_kernel/texture/ConcatKernel.metal b/lite/backends/metal/metal_kernel/texture/ConcatKernel.metal index 33195459962..e6f8c752b92 100644 --- a/lite/backends/metal/metal_kernel/texture/ConcatKernel.metal +++ b/lite/backends/metal/metal_kernel/texture/ConcatKernel.metal @@ -12,356 +12,181 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include - #include "Common.metal" - +#include using namespace metal; struct ConcatParam { int32_t odim[4]; int32_t axis; int32_t offset; + int32_t num; + int32_t v_; int32_t trans[4]; int32_t vdim[6]; }; -#define VNORMAL 1 -#define VX 2 -#define VY 3 -#define VZ 4 - -// R:input dim size N:input number V: direction - -// >> normal mode (loop mode) - -// >> fast mode -// only support concat_{2,3,4}_{2,3,4,5,6}_y_{float,half} -// only support concat_{3,4}_{2,3,4,5,6}_x_{float,half} -// only support concat_{1,2,3,4}_{2,3,4,5,6}_z_{float,half} - -// >> special model -// lens: (R=4, N=3, V=normal) -// lens: (R=2, N=3, V=normal) -// lens: (R=2, N=2, V=normal) -// lens: (R=4, N=2, V=z) -// ssd-ar: (R=4, N=3, V=z) -// ssd-ar: (R=3, N=2, V=y) -// ssd-ar: (R=3, N=5, V=x) -// ssd-ar: (R=2, N=5, V=x) -// ssd: (R=2, N=6, V=y), -// ssd: (R=3, N=6, V=y) -// genet: (R=4, N=2, V=normal) -// gesture recognizing: (R=2, N=3, V=x) - -#pragma mark - -#pragma mark normal - -#define V VNORMAL -#define R 4 -#define N 4 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VNORMAL -#define R 4 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VNORMAL -#define R 4 -#define N 2 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VNORMAL -#define R 3 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VNORMAL -#define R 3 -#define N 2 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VNORMAL -#define R 2 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VNORMAL -#define R 2 -#define N 2 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#pragma mark - -#pragma mark z - -#define V VZ -#define R 4 -#define N 6 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 4 -#define N 5 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 4 -#define N 4 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 4 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 4 -#define N 2 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 3 -#define N 5 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VZ -#define R 3 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#pragma mark - -#pragma mark x - -#define V VX -#define R 3 -#define N 6 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VX -#define R 3 -#define N 5 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VX -#define R 3 -#define N 4 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VX -#define R 3 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VX -#define R 3 -#define N 2 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VX -#define R 2 -#define N 6 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VX -#define R 2 -#define N 5 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VX -#define R 2 -#define N 4 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VX -#define R 2 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#pragma mark - -#pragma mark y - -#define V VY -#define R 4 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VY -#define R 3 -#define N 6 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VY -#define R 3 -#define N 3 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VY -#define R 3 -#define N 2 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VY -#define R 2 -#define N 6 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VY -#define R 2 -#define N 5 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V - -#define V VY -#define R 2 -#define N 2 -#define P ftype -#include "ConcatKernel.inc.metal" -#undef P -#undef N -#undef R -#undef V +kernel void concat_normal(texture2d_array inx[[texture(0)]], + texture2d_array out[[texture(1)]], + texture2d_array in0[[texture(2)]], + texture2d_array in1[[texture(3)]], + texture2d_array in2[[texture(4)]], + texture2d_array in3[[texture(5)]], + texture2d_array in4[[texture(6)]], + texture2d_array in5[[texture(7)]], + constant ConcatParam& pm[[buffer(0)]], + uint3 gid[[thread_position_in_grid]]) { + ConcatParam cp = pm; + int n = pm.num; + int xyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, abcd[4], oxyzn[4]; + ftype4 r = inx.read(gid.xy, gid.z); + for (int i = 0; i < 4; i++) { + xyzn[3] = i; + xyzn2abcd_4(cp.odim[3], xyzn, abcd); + int k = abcd[cp.axis] - cp.offset; + if (k < 0) continue; + int j = 0; + for (; j < n; j++) { + if (k < cp.vdim[j]) { + break; + } + k -= cp.vdim[j]; + } + if (j == n) { + continue; + } + int ta = cp.odim[cp.axis]; + abcd[cp.axis] = k; + cp.odim[cp.axis] = cp.vdim[j]; + abcd2xyzn_4(cp.odim[3], abcd, oxyzn); + cp.odim[cp.axis] = ta; + switch (j) { + case 0: + r[i] = in0.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; + break; + case 1: + r[i] = in1.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; + break; + case 2: + r[i] = in2.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; + break; + case 3: + r[i] = in3.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; + break; + case 4: + r[i] = in4.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; + break; + case 5: + r[i] = in5.read(uint2(oxyzn[0], oxyzn[1]), oxyzn[2])[oxyzn[3]]; + break; + } + } + out.write(r, gid.xy, gid.z); +} + +kernel void concat(texture2d_array out[[texture(0)]], + texture2d_array in0[[texture(1)]], + texture2d_array in1[[texture(2)]], + texture2d_array in2[[texture(3)]], + texture2d_array in3[[texture(4)]], + texture2d_array in4[[texture(5)]], + texture2d_array in5[[texture(6)]], + constant ConcatParam& pm[[buffer(0)]], + uint3 gid[[thread_position_in_grid]]) { + int n = pm.num; + int v_ = pm.v_; + if (v_ == 2) { + int x = gid.x - pm.offset; + if (x < 0) return; + ftype4 r; + for (int i = 0; i < n; i++) { + if (i > 0) x -= pm.vdim[i - 1]; + if (x < pm.vdim[i]) { + switch (i) { + case 0: + r = in0.read(gid.xy, gid.z); + break; + case 1: + r = in1.read(uint2(x, gid.y), gid.z); + break; + case 2: + r = in2.read(uint2(x, gid.y), gid.z); + break; + case 3: + r = in3.read(uint2(x, gid.y), gid.z); + break; + case 4: + r = in4.read(uint2(x, gid.y), gid.z); + break; + case 5: + r = in5.read(uint2(x, gid.y), gid.z); + break; + } + out.write(r, gid.xy, gid.z); + return; + } + } + } else if (v_ == 3) { + int y = gid.y - pm.offset; + if (y < 0) return; + ftype4 r; + for (int i = 0; i < n; i++) { + if (i > 0) y -= pm.vdim[i - 1]; + if (y < pm.vdim[i]) { + switch (i) { + case 0: + r = in0.read(gid.xy, gid.z); + break; + case 1: + r = in1.read(uint2(gid.x, y), gid.z); + break; + case 2: + r = in2.read(uint2(gid.x, y), gid.z); + break; + case 3: + r = in3.read(uint2(gid.x, y), gid.z); + break; + case 4: + r = in4.read(uint2(gid.x, y), gid.z); + break; + case 5: + r = in5.read(uint2(gid.x, y), gid.z); + break; + } + out.write(r, gid.xy, gid.z); + return; + } + } + } else if (v_ == 4) { + int z = gid.z - pm.offset; + if (z < 0) return; + ftype4 r; + for (int i = 0; i < n; i++) { + if (i > 0) z -= pm.vdim[i - 1]; + if (z < pm.vdim[i]) { + switch (i) { + case 0: + r = in0.read(gid.xy, gid.z); + break; + case 1: + r = in1.read(gid.xy, z); + break; + case 2: + r = in2.read(gid.xy, z); + break; + case 3: + r = in3.read(gid.xy, z); + break; + case 4: + r = in4.read(gid.xy, z); + break; + case 5: + r = in5.read(gid.xy, z); + break; + } + out.write(r, gid.xy, gid.z); + return; + } + } + } +} diff --git a/lite/kernels/metal/image_op/concat_image_compute.mm b/lite/kernels/metal/image_op/concat_image_compute.mm index a98660cfb09..e6e8960f80c 100644 --- a/lite/kernels/metal/image_op/concat_image_compute.mm +++ b/lite/kernels/metal/image_op/concat_image_compute.mm @@ -56,13 +56,14 @@ int idx = 0; auto encoder = [backend commandEncoder]; - for (auto item : input_buffers_) { - [encoder setTexture:item->image() atIndex:(idx++)]; - } [encoder setTexture:output_buffer_->image() atIndex:(idx++)]; if (v_ == "normal") { - [encoder setTexture:output_buffer_->image() atIndex:(idx)]; + [encoder setTexture:output_buffer_->image() atIndex:(idx++)]; + } + for (auto item : input_buffers_) { + [encoder setTexture:item->image() atIndex:(idx++)]; } + [encoder setBuffer:params_buffer_->buffer() offset:(0) atIndex:(0)]; [backend dispatchEncoder:encoder pipline:pipline outTexture:outTexture]; @@ -72,7 +73,7 @@ void ConcatImageCompute::setup_without_mps() { const auto& param = this->Param(); int num = (int)param.x.size(); - + int vaxis = 0; int axis = int(4 - output_buffer_->tensor_dim_.size() + param.axis); auto* axis_tensor = param.axis_tensor; if (axis_tensor != nullptr) { @@ -132,9 +133,21 @@ odm[4 - orank + i] = (int)(output_buffer_->tensor_dim_[i]); } } + + if (v_ == "normal") + vaxis = 1; + else if (v_ == "x") + vaxis = 2; + else if (v_ == "y") + vaxis = 3; + else if (v_ == "z") + vaxis = 4; + ConcatMetalParam concat_params{{odm[0], odm[1], odm[2], odm[3]}, static_cast(axis), 0, + num, + vaxis, {transpose[0], transpose[1], transpose[2], transpose[3]}, {(int)vdim[0], (int)vdim[1], (int)vdim[2], (int)vdim[3], (int)vdim[4], (int)vdim[5]}}; @@ -142,7 +155,10 @@ std::make_shared(metal_context_, sizeof(concat_params), &concat_params); #ifdef LITE_WITH_METAL_FULL #else - function_name_ = "concat_" + std::to_string(orank) + "_" + std::to_string(num) + "_" + v_; + if (v_ == "normal") + function_name_ = "concat_normal"; + else + function_name_ = "concat"; #endif auto backend = (__bridge MetalContextImp*)metal_context_->backend(); pipline_ = [backend pipline:function_name_]; diff --git a/lite/kernels/metal/image_op/metal_params.h b/lite/kernels/metal/image_op/metal_params.h index e54bcd004a9..ca78fecd026 100644 --- a/lite/kernels/metal/image_op/metal_params.h +++ b/lite/kernels/metal/image_op/metal_params.h @@ -131,6 +131,8 @@ struct ConcatMetalParam { int odim[4]; int axis; int offset; + int num; + int v_; int trans[4]; int vdim[6]; };