Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenCL][Kernel] Add concat multi inputs kernel except channel is not aligned #6075

Merged
merged 7 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lite/backends/opencl/cl_kernel/cl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License. */
#define MAX_VALUE FLT_MAX
#define MIN_VALUE -FLT_MAX

// Ceiling division for non-negative x and positive y: the number of y-sized
// groups needed to cover x items. Note that y is expanded twice, so the
// arguments should be free of side effects.
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

/////////////////////////////////
// CL_DTYPE_float / CL_DTYPE_half
/////////////////////////////////
Expand Down
362 changes: 311 additions & 51 deletions lite/backends/opencl/cl_kernel/image/concat_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,317 @@ limitations under the License. */

#include <cl_common.h>

/********************************************************
* For the case where axis C of all inputs is aligned: Start
********************************************************/
// Common prologue of the generated Concat* kernels.
//
// Reads the three global ids into c_blk_idx (dim 0), w_idx (dim 1) and
// nh_idx (dim 2), and returns early when any of them lies outside the output
// extent: output_shape.y for dim 0, output_shape.w for dim 1 and
// output_shape.x * output_shape.z for dim 2.
//
// Also declares `result`, the texel that the DOConcat* macros fill in and
// WRITE_IMG_DATA stores. It is zero-initialized because the DOConcat* macros
// have no final `else`: if an index falls outside every input range (e.g. the
// host passed an output_shape larger than the sum of the inputs), a
// deterministic zero is written instead of an indeterminate value.
#define CHECK_IDX \
  int c_blk_idx = get_global_id(0); \
  int w_idx = get_global_id(1); \
  int nh_idx = get_global_id(2); \
  if (c_blk_idx >= output_shape.y || \
      w_idx >= output_shape.w || \
      nh_idx >= output_shape.x * output_shape.z) { \
    return; \
  } \
  CL_DTYPE4 result = (CL_DTYPE4)(0.0f);


// axis = 0
// Concatenation along N (the .x component of the int4 shape vectors).
// Every macro in this family expects c_blk_idx, w_idx, nh_idx and result to
// be in scope already (CHECK_IDX declares them), together with the kernel
// arguments inputK, input_shapeK and output_shape.
// Calling enqueueCopyImage directly is also OK but may be slower than kernel impl.
//
// Two inputs: splits nh_idx into n_idx = nh_idx / output_shape.z and
// h_idx = nh_idx % output_shape.z, then selects the input whose range
// [0, N0) or [N0, N0 + N1) contains n_idx. The texel is read at
//   x = c_blk_idx * input_shape0.w + w_idx   (input 0's W is used for all inputs)
//   y = (n_idx - start of the range) * input_shapeK.z + h_idx
// If n_idx is not below boundary1, this macro does not assign result.
#define DOConcat2InputAxis0 \
int n_idx = nh_idx / output_shape.z; \
int h_idx = nh_idx % output_shape.z; \
int boundary0 = input_shape0.x; /* N0 */ \
int boundary1 = boundary0 + input_shape1.x; /* N0 + N1 */ \
int2 input_pos; \
input_pos.x = c_blk_idx * input_shape0.w + w_idx; \
if (n_idx < boundary0) { \
input_pos.y = n_idx * input_shape0.z + h_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, SAMPLER, input_pos); \
} else if (n_idx < boundary1) { \
input_pos.y = (n_idx - boundary0) * input_shape1.z + h_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, SAMPLER, input_pos); \
}

// Three inputs: expands the two-input macro, then handles n_idx in
// [boundary1, boundary2) by reading from input2. Reuses the n_idx, h_idx,
// boundary1 and input_pos locals declared by the expansion above.
#define DOConcat3InputAxis0 \
DOConcat2InputAxis0; \
int boundary2 = boundary1 + input_shape2.x; \
if (n_idx >= boundary1 && n_idx < boundary2) { \
input_pos.y = (n_idx - boundary1) * input_shape2.z + h_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input2, SAMPLER, input_pos); \
}

// Four inputs: three-input macro plus the range [boundary2, boundary3) served by input3.
#define DOConcat4InputAxis0 \
DOConcat3InputAxis0; \
int boundary3 = boundary2 + input_shape3.x; \
if (n_idx >= boundary2 && n_idx < boundary3) { \
input_pos.y = (n_idx - boundary2) * input_shape3.z + h_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input3, SAMPLER, input_pos); \
}

// Five inputs: four-input macro plus the range [boundary3, boundary4) served by input4.
#define DOConcat5InputAxis0 \
DOConcat4InputAxis0; \
int boundary4 = boundary3 + input_shape4.x; \
if (n_idx >= boundary3 && n_idx < boundary4) { \
input_pos.y = (n_idx - boundary3) * input_shape4.z + h_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input4, SAMPLER, input_pos); \
}

// Six inputs: five-input macro plus the range [boundary4, boundary5) served by input5.
#define DOConcat6InputAxis0 \
DOConcat5InputAxis0; \
int boundary5 = boundary4 + input_shape5.x; \
if (n_idx >= boundary4 && n_idx < boundary5) { \
input_pos.y = (n_idx - boundary4) * input_shape5.z + h_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input5, SAMPLER, input_pos); \
}


// axis = 1
// Concatenation along the channel-block dimension (the .y component of the
// int4 shape vectors). Every macro in this family expects c_blk_idx, w_idx,
// nh_idx and result to be in scope already (CHECK_IDX declares them),
// together with the kernel arguments inputK and input_shapeK.
//
// Two inputs: selects the input whose range [0, C_blk0) or
// [C_blk0, C_blk0 + C_blk1) contains c_blk_idx. The texel is read at
//   x = (c_blk_idx - start of the range) * input_shapeK.w + w_idx
//   y = nh_idx
// If c_blk_idx is not below boundary1, this macro does not assign result.
#define DOConcat2InputAxis1 \
int boundary0 = input_shape0.y; /* C_blk0 */ \
int boundary1 = boundary0 + input_shape1.y; /* C_blk0 + C_blk1 */ \
int2 input_pos; \
input_pos.y = nh_idx; \
if (c_blk_idx < boundary0) { \
input_pos.x = c_blk_idx * input_shape0.w + w_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, SAMPLER, input_pos); \
} else if (c_blk_idx < boundary1) { \
input_pos.x = (c_blk_idx - boundary0) * input_shape1.w + w_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, SAMPLER, input_pos); \
}

// Three inputs: expands the two-input macro, then handles c_blk_idx in
// [boundary1, boundary2) by reading from input2. Reuses the boundary1 and
// input_pos locals declared by the expansion above.
#define DOConcat3InputAxis1 \
DOConcat2InputAxis1; \
int boundary2 = boundary1 + input_shape2.y; \
if (c_blk_idx >= boundary1 && c_blk_idx < boundary2) { \
input_pos.x = (c_blk_idx - boundary1) * input_shape2.w + w_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input2, SAMPLER, input_pos); \
}

// Four inputs: three-input macro plus the range [boundary2, boundary3) served by input3.
#define DOConcat4InputAxis1 \
DOConcat3InputAxis1; \
int boundary3 = boundary2 + input_shape3.y; \
if (c_blk_idx >= boundary2 && c_blk_idx < boundary3) { \
input_pos.x = (c_blk_idx - boundary2) * input_shape3.w + w_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input3, SAMPLER, input_pos); \
}

// Five inputs: four-input macro plus the range [boundary3, boundary4) served by input4.
#define DOConcat5InputAxis1 \
DOConcat4InputAxis1; \
int boundary4 = boundary3 + input_shape4.y; \
if (c_blk_idx >= boundary3 && c_blk_idx < boundary4) { \
input_pos.x = (c_blk_idx - boundary3) * input_shape4.w + w_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input4, SAMPLER, input_pos); \
}

// Six inputs: five-input macro plus the range [boundary4, boundary5) served by input5.
#define DOConcat6InputAxis1 \
DOConcat5InputAxis1; \
int boundary5 = boundary4 + input_shape5.y; \
if (c_blk_idx >= boundary4 && c_blk_idx < boundary5) { \
input_pos.x = (c_blk_idx - boundary4) * input_shape5.w + w_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input5, SAMPLER, input_pos); \
}


// axis = 2
// Concatenation along H (the .z component of the int4 shape vectors).
// Every macro in this family expects c_blk_idx, w_idx, nh_idx and result to
// be in scope already (CHECK_IDX declares them), together with the kernel
// arguments inputK, input_shapeK and output_shape.
//
// Two inputs: splits nh_idx into n_idx = nh_idx / output_shape.z and
// h_idx = nh_idx % output_shape.z, then selects the input whose range
// [0, H0) or [H0, H0 + H1) contains h_idx. The texel is read at
//   x = c_blk_idx * input_shape0.w + w_idx   (input 0's W is used for all inputs)
//   y = n_idx * input_shapeK.z + h_idx - start of the range
// If h_idx is not below boundary1, this macro does not assign result.
#define DOConcat2InputAxis2 \
int n_idx = nh_idx / output_shape.z; \
int h_idx = nh_idx % output_shape.z; \
int boundary0 = input_shape0.z; /* H0 */ \
int boundary1 = boundary0 + input_shape1.z; /* H0 + H1 */ \
int2 input_pos; \
input_pos.x = c_blk_idx * input_shape0.w + w_idx; \
if (h_idx < boundary0) { \
input_pos.y = n_idx * input_shape0.z + h_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, SAMPLER, input_pos); \
} else if (h_idx < boundary1) { \
input_pos.y = n_idx * input_shape1.z + h_idx - boundary0; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, SAMPLER, input_pos); \
}

// Three inputs: expands the two-input macro, then handles h_idx in
// [boundary1, boundary2) by reading from input2. Reuses the n_idx, h_idx,
// boundary1 and input_pos locals declared by the expansion above.
#define DOConcat3InputAxis2 \
DOConcat2InputAxis2; \
int boundary2 = boundary1 + input_shape2.z; \
if (h_idx >= boundary1 && h_idx < boundary2) { \
input_pos.y = n_idx * input_shape2.z + h_idx - boundary1; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input2, SAMPLER, input_pos); \
}

// Four inputs: three-input macro plus the range [boundary2, boundary3) served by input3.
#define DOConcat4InputAxis2 \
DOConcat3InputAxis2; \
int boundary3 = boundary2 + input_shape3.z; \
if (h_idx >= boundary2 && h_idx < boundary3) { \
input_pos.y = n_idx * input_shape3.z + h_idx - boundary2; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input3, SAMPLER, input_pos); \
}

// Five inputs: four-input macro plus the range [boundary3, boundary4) served by input4.
#define DOConcat5InputAxis2 \
DOConcat4InputAxis2; \
int boundary4 = boundary3 + input_shape4.z; \
if (h_idx >= boundary3 && h_idx < boundary4) { \
input_pos.y = n_idx * input_shape4.z + h_idx - boundary3; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input4, SAMPLER, input_pos); \
}

// Six inputs: five-input macro plus the range [boundary4, boundary5) served by input5.
#define DOConcat6InputAxis2 \
DOConcat5InputAxis2; \
int boundary5 = boundary4 + input_shape5.z; \
if (h_idx >= boundary4 && h_idx < boundary5) { \
input_pos.y = n_idx * input_shape5.z + h_idx - boundary4; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input5, SAMPLER, input_pos); \
}


// axis = 3
// Concatenation along W (the .w component of the int4 shape vectors).
// Every macro in this family expects c_blk_idx, w_idx, nh_idx and result to
// be in scope already (CHECK_IDX declares them), together with the kernel
// arguments inputK and input_shapeK.
//
// Two inputs: selects the input whose range [0, W0) or [W0, W0 + W1)
// contains w_idx. The texel is read at
//   x = c_blk_idx * input_shapeK.w + w_idx - start of the range
//   y = nh_idx
// If w_idx is not below boundary1, this macro does not assign result.
#define DOConcat2InputAxis3 \
int boundary0 = input_shape0.w; /* W0 */ \
int boundary1 = boundary0 + input_shape1.w; /* W0 + W1 */ \
int2 input_pos; \
input_pos.y = nh_idx; \
if (w_idx < boundary0) { \
input_pos.x = c_blk_idx * input_shape0.w + w_idx; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, SAMPLER, input_pos); \
} else if (w_idx < boundary1) { \
input_pos.x = c_blk_idx * input_shape1.w + w_idx - boundary0; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, SAMPLER, input_pos); \
}

// Three inputs: expands the two-input macro, then handles w_idx in
// [boundary1, boundary2) by reading from input2. Reuses the boundary1 and
// input_pos locals declared by the expansion above.
#define DOConcat3InputAxis3 \
DOConcat2InputAxis3; \
int boundary2 = boundary1 + input_shape2.w; \
if (w_idx >= boundary1 && w_idx < boundary2) { \
input_pos.x = c_blk_idx * input_shape2.w + w_idx - boundary1; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input2, SAMPLER, input_pos); \
}

// Four inputs: three-input macro plus the range [boundary2, boundary3) served by input3.
#define DOConcat4InputAxis3 \
DOConcat3InputAxis3; \
int boundary3 = boundary2 + input_shape3.w; \
if (w_idx >= boundary2 && w_idx < boundary3) { \
input_pos.x = c_blk_idx * input_shape3.w + w_idx - boundary2; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input3, SAMPLER, input_pos); \
}

// Five inputs: four-input macro plus the range [boundary3, boundary4) served by input4.
#define DOConcat5InputAxis3 \
DOConcat4InputAxis3; \
int boundary4 = boundary3 + input_shape4.w; \
if (w_idx >= boundary3 && w_idx < boundary4) { \
input_pos.x = c_blk_idx * input_shape4.w + w_idx - boundary3; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input4, SAMPLER, input_pos); \
}

// Six inputs: five-input macro plus the range [boundary4, boundary5) served by input5.
#define DOConcat6InputAxis3 \
DOConcat5InputAxis3; \
int boundary5 = boundary4 + input_shape5.w; \
if (w_idx >= boundary4 && w_idx < boundary5) { \
input_pos.x = c_blk_idx * input_shape5.w + w_idx - boundary4; \
result = READ_IMG_TYPE(CL_DTYPE_CHAR, input5, SAMPLER, input_pos); \
}


// Common kernel epilogue: stores `result` into `output` at
// (c_blk_idx * output_shape.w + w_idx, nh_idx). Expects c_blk_idx, w_idx,
// nh_idx and result to be in scope (CHECK_IDX declares them).
#define WRITE_IMG_DATA \
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, \
                 output, \
                 (int2)(c_blk_idx * output_shape.w + w_idx, nh_idx), \
                 result);

// Generates the two-input kernel named Concat<INPUT_TAG><AXIS_TAG>.
// The kernel body is CHECK_IDX (index setup and bounds check), then
// DOConcat<INPUT_TAG><AXIS_TAG> (pick and read the source texel), then
// WRITE_IMG_DATA (store it into `output`).
#define CONCAT2(INPUT_TAG, AXIS_TAG) \
  __kernel void Concat##INPUT_TAG##AXIS_TAG( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##INPUT_TAG##AXIS_TAG \
    WRITE_IMG_DATA \
  }

// Generates the three-input kernel named Concat<INPUT_TAG><AXIS_TAG>.
// The kernel body is CHECK_IDX (index setup and bounds check), then
// DOConcat<INPUT_TAG><AXIS_TAG> (pick and read the source texel), then
// WRITE_IMG_DATA (store it into `output`).
#define CONCAT3(INPUT_TAG, AXIS_TAG) \
  __kernel void Concat##INPUT_TAG##AXIS_TAG( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __read_only image2d_t input2, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 input_shape2, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##INPUT_TAG##AXIS_TAG \
    WRITE_IMG_DATA \
  }

// Generates the four-input kernel named Concat<INPUT_TAG><AXIS_TAG>.
// The kernel body is CHECK_IDX (index setup and bounds check), then
// DOConcat<INPUT_TAG><AXIS_TAG> (pick and read the source texel), then
// WRITE_IMG_DATA (store it into `output`).
#define CONCAT4(INPUT_TAG, AXIS_TAG) \
  __kernel void Concat##INPUT_TAG##AXIS_TAG( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __read_only image2d_t input2, \
      __read_only image2d_t input3, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 input_shape2, \
      int4 input_shape3, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##INPUT_TAG##AXIS_TAG \
    WRITE_IMG_DATA \
  }

// Generates the five-input kernel named Concat<INPUT_TAG><AXIS_TAG>.
// The kernel body is CHECK_IDX (index setup and bounds check), then
// DOConcat<INPUT_TAG><AXIS_TAG> (pick and read the source texel), then
// WRITE_IMG_DATA (store it into `output`).
#define CONCAT5(INPUT_TAG, AXIS_TAG) \
  __kernel void Concat##INPUT_TAG##AXIS_TAG( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __read_only image2d_t input2, \
      __read_only image2d_t input3, \
      __read_only image2d_t input4, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 input_shape2, \
      int4 input_shape3, \
      int4 input_shape4, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##INPUT_TAG##AXIS_TAG \
    WRITE_IMG_DATA \
  }

// Generates the six-input kernel named Concat<INPUT_TAG><AXIS_TAG>.
// The kernel body is CHECK_IDX (index setup and bounds check), then
// DOConcat<INPUT_TAG><AXIS_TAG> (pick and read the source texel), then
// WRITE_IMG_DATA (store it into `output`).
#define CONCAT6(INPUT_TAG, AXIS_TAG) \
  __kernel void Concat##INPUT_TAG##AXIS_TAG( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __read_only image2d_t input2, \
      __read_only image2d_t input3, \
      __read_only image2d_t input4, \
      __read_only image2d_t input5, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 input_shape2, \
      int4 input_shape3, \
      int4 input_shape4, \
      int4 input_shape5, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##INPUT_TAG##AXIS_TAG \
    WRITE_IMG_DATA \
  }

// Kernel instantiations. Each line below expands to one __kernel definition
// whose name is the token paste Concat<first arg><second arg>, e.g.
// CONCAT2(2Input, Axis0) defines Concat2InputAxis0 with the body taken from
// DOConcat2InputAxis0. The CONCATn macro fixes the number of image and shape
// arguments; the first argument must name the matching input count.
// axis = 0 (concatenate along shape.x)
CONCAT2(2Input, Axis0)
CONCAT3(3Input, Axis0)
CONCAT4(4Input, Axis0)
CONCAT5(5Input, Axis0)
CONCAT6(6Input, Axis0)
// axis = 1 (concatenate along shape.y)
CONCAT2(2Input, Axis1)
CONCAT3(3Input, Axis1)
CONCAT4(4Input, Axis1)
CONCAT5(5Input, Axis1)
CONCAT6(6Input, Axis1)
// axis = 2 (concatenate along shape.z)
CONCAT2(2Input, Axis2)
CONCAT3(3Input, Axis2)
CONCAT4(4Input, Axis2)
CONCAT5(5Input, Axis2)
CONCAT6(6Input, Axis2)
// axis = 3 (concatenate along shape.w)
CONCAT2(2Input, Axis3)
CONCAT3(3Input, Axis3)
CONCAT4(4Input, Axis3)
CONCAT5(5Input, Axis3)
CONCAT6(6Input, Axis3)
/********************************************************
* For the case where axis C of all inputs is aligned: End
********************************************************/


// deprecated
__kernel void concatByCWith2Inputs(
__write_only image2d_t output_image,
Expand Down Expand Up @@ -222,57 +533,6 @@ __kernel void concatByCWith4Inputs(
}


// deprecated
// Copies one texel of input_image into output_image, shifted down by
// out_H_Start rows.
//   source:      (id0 * out_W + id1, id2)
//   destination: (same x, out_H_Start + id2)
// where id0/id1/id2 are the work-item's global ids. No bounds check is
// performed, so the launch grid must not exceed the image extents.
__kernel void concatByH(__read_only image2d_t input_image,
                        __write_only image2d_t output_image,
                        __private const int out_W,
                        __private const int out_H_Start) {
  const int blk = get_global_id(0);
  const int col = get_global_id(1);
  const int row = get_global_id(2);

  // Source coordinate inside the input image.
  const int2 src = (int2)(blk * out_W + col, row);
  const CL_DTYPE4 texel =
      READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, src);

  // Same column, row offset by the rows already occupied in the output.
  const int2 dst = (int2)(src.x, out_H_Start + src.y);
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, dst, texel);
}


// deprecated
// Copies one texel of input_image into output_image at a shifted column.
//   source:      (id0 * in_W + id1, id2)
//   destination: (source.x + pre_Width + out_Width * id0, id2)
// where id0/id1/id2 are the work-item's global ids. No bounds check is
// performed, so the launch grid must not exceed the image extents.
// NOTE(review): the destination column keeps the id0 * in_W term and adds
// out_Width * id0 on top of it; confirm against the host code that out_Width
// is passed with that in mind.
__kernel void concatByW(__read_only image2d_t input_image,
                        __write_only image2d_t output_image,
                        __private const int in_W,
                        __private const int pre_Width,
                        __private const int out_Width) {
  const int blk = get_global_id(0);
  const int col = get_global_id(1);
  const int row = get_global_id(2);

  // Source coordinate inside the input image.
  const int2 src = (int2)(blk * in_W + col, row);
  const CL_DTYPE4 texel =
      READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, src);

  // Same row, column shifted by pre_Width plus a per-block stride.
  const int2 dst = (int2)(src.x + pre_Width + out_Width * blk, src.y);
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, dst, texel);
}


__kernel void concat2(__read_only image2d_t input0,
__read_only image2d_t input1,
__write_only image2d_t output,
Expand Down
Loading