Skip to content

Commit

Permalink
[OpenCL][Kernel] Add concat multi inputs kernel except channel is not…
Browse files Browse the repository at this point in the history
… aligned (#6075)
  • Loading branch information
zhaoyang-star authored May 13, 2021
1 parent bfbe03e commit 6478459
Show file tree
Hide file tree
Showing 4 changed files with 424 additions and 107 deletions.
2 changes: 2 additions & 0 deletions lite/backends/opencl/cl_kernel/cl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License. */
/* Largest finite float and its negation (the most negative finite float). */
#define MAX_VALUE FLT_MAX
/* Parenthesized so the unary minus cannot bind to neighbouring tokens. */
#define MIN_VALUE (-FLT_MAX)

/* Integer ceiling division: for x >= 0 and y > 0 yields the smallest q with
 * q * y >= x. Note that y is evaluated twice, so pass side-effect-free
 * expressions only. */
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))

/////////////////////////////////
// CL_DTYPE_float / CL_DTYPE_half
/////////////////////////////////
Expand Down
362 changes: 311 additions & 51 deletions lite/backends/opencl/cl_kernel/image/concat_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,317 @@ limitations under the License. */

#include <cl_common.h>

/********************************************************
* Case where the C (channel) axis of every input is aligned: Start
********************************************************/
/* Common prologue of every generated concat kernel.
 *
 * Declares c_blk_idx / w_idx / nh_idx from global ids 0 / 1 / 2 and returns
 * from the kernel when any of them lies outside the output extent
 * (output_shape.y, output_shape.w and output_shape.x * output_shape.z
 * respectively). Also declares `result`, the texel that WRITE_IMG_DATA stores.
 *
 * `result` starts at zero: the DOConcat* branches only assign it when the
 * index falls inside one of the input ranges, so without an initial value a
 * mismatch between output_shape and the input shapes would store an
 * indeterminate texel.
 */
#define CHECK_IDX \
  int c_blk_idx = get_global_id(0); \
  int w_idx = get_global_id(1); \
  int nh_idx = get_global_id(2); \
  if (c_blk_idx >= output_shape.y || \
      w_idx >= output_shape.w || \
      nh_idx >= output_shape.x * output_shape.z) { \
    return; \
  } \
  CL_DTYPE4 result = (CL_DTYPE4)(0);


// axis = 0 (N): the inputs are stacked along input_shape.x.
// The 2-input macro splits nh_idx into n_idx / h_idx using output_shape.z and
// computes the source column src_x (input_shape0.w is used for every input).
// n_endK is the running sum N0 + ... + NK; the input whose range contains
// n_idx supplies `result`. Each larger macro expands the previous one and
// appends the test for one more input.
// Calling enqueueCopyImage directly would also work, but may be slower than
// this kernel implementation.
#define DOConcat2InputAxis0 \
  const int n_idx = nh_idx / output_shape.z; \
  const int h_idx = nh_idx % output_shape.z; \
  const int src_x = w_idx + c_blk_idx * input_shape0.w; \
  const int n_end0 = input_shape0.x;          /* N0 */ \
  const int n_end1 = n_end0 + input_shape1.x; /* N0 + N1 */ \
  if (n_idx < n_end0) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, SAMPLER, \
        (int2)(src_x, h_idx + n_idx * input_shape0.z)); \
  } else if (n_idx < n_end1) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, SAMPLER, \
        (int2)(src_x, h_idx + (n_idx - n_end0) * input_shape1.z)); \
  }

#define DOConcat3InputAxis0 \
  DOConcat2InputAxis0; \
  const int n_end2 = n_end1 + input_shape2.x; \
  if (n_end1 <= n_idx && n_idx < n_end2) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input2, SAMPLER, \
        (int2)(src_x, h_idx + (n_idx - n_end1) * input_shape2.z)); \
  }

#define DOConcat4InputAxis0 \
  DOConcat3InputAxis0; \
  const int n_end3 = n_end2 + input_shape3.x; \
  if (n_end2 <= n_idx && n_idx < n_end3) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input3, SAMPLER, \
        (int2)(src_x, h_idx + (n_idx - n_end2) * input_shape3.z)); \
  }

#define DOConcat5InputAxis0 \
  DOConcat4InputAxis0; \
  const int n_end4 = n_end3 + input_shape4.x; \
  if (n_end3 <= n_idx && n_idx < n_end4) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input4, SAMPLER, \
        (int2)(src_x, h_idx + (n_idx - n_end3) * input_shape4.z)); \
  }

#define DOConcat6InputAxis0 \
  DOConcat5InputAxis0; \
  const int n_end5 = n_end4 + input_shape5.x; \
  if (n_end4 <= n_idx && n_idx < n_end5) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input5, SAMPLER, \
        (int2)(src_x, h_idx + (n_idx - n_end4) * input_shape5.z)); \
  }


// axis = 1 (C): the inputs are stacked along input_shape.y (channel blocks).
// c_endK is the running sum C_blk0 + ... + C_blkK. The input whose range
// contains c_blk_idx supplies `result`; the source row is always nh_idx and
// the source column is rebased to that input's first channel block, using
// that input's own width (input_shapeK.w).
#define DOConcat2InputAxis1 \
  const int c_end0 = input_shape0.y;          /* C_blk0 */ \
  const int c_end1 = c_end0 + input_shape1.y; /* C_blk0 + C_blk1 */ \
  if (c_blk_idx < c_end0) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, SAMPLER, \
        (int2)(w_idx + c_blk_idx * input_shape0.w, nh_idx)); \
  } else if (c_blk_idx < c_end1) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, SAMPLER, \
        (int2)(w_idx + (c_blk_idx - c_end0) * input_shape1.w, nh_idx)); \
  }

#define DOConcat3InputAxis1 \
  DOConcat2InputAxis1; \
  const int c_end2 = c_end1 + input_shape2.y; \
  if (c_end1 <= c_blk_idx && c_blk_idx < c_end2) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input2, SAMPLER, \
        (int2)(w_idx + (c_blk_idx - c_end1) * input_shape2.w, nh_idx)); \
  }

#define DOConcat4InputAxis1 \
  DOConcat3InputAxis1; \
  const int c_end3 = c_end2 + input_shape3.y; \
  if (c_end2 <= c_blk_idx && c_blk_idx < c_end3) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input3, SAMPLER, \
        (int2)(w_idx + (c_blk_idx - c_end2) * input_shape3.w, nh_idx)); \
  }

#define DOConcat5InputAxis1 \
  DOConcat4InputAxis1; \
  const int c_end4 = c_end3 + input_shape4.y; \
  if (c_end3 <= c_blk_idx && c_blk_idx < c_end4) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input4, SAMPLER, \
        (int2)(w_idx + (c_blk_idx - c_end3) * input_shape4.w, nh_idx)); \
  }

#define DOConcat6InputAxis1 \
  DOConcat5InputAxis1; \
  const int c_end5 = c_end4 + input_shape5.y; \
  if (c_end4 <= c_blk_idx && c_blk_idx < c_end5) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input5, SAMPLER, \
        (int2)(w_idx + (c_blk_idx - c_end4) * input_shape5.w, nh_idx)); \
  }


// axis = 2 (H): the inputs are stacked along input_shape.z.
// nh_idx is split into n_idx / h_idx using output_shape.z; h_endK is the
// running sum H0 + ... + HK. The input whose range contains h_idx supplies
// `result`; its source row is n_idx * HK plus h_idx rebased to that input's
// first row. The source column src_x uses input_shape0.w for every input.
#define DOConcat2InputAxis2 \
  const int n_idx = nh_idx / output_shape.z; \
  const int h_idx = nh_idx % output_shape.z; \
  const int src_x = w_idx + c_blk_idx * input_shape0.w; \
  const int h_end0 = input_shape0.z;          /* H0 */ \
  const int h_end1 = h_end0 + input_shape1.z; /* H0 + H1 */ \
  if (h_idx < h_end0) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, SAMPLER, \
        (int2)(src_x, n_idx * input_shape0.z + h_idx)); \
  } else if (h_idx < h_end1) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, SAMPLER, \
        (int2)(src_x, n_idx * input_shape1.z + h_idx - h_end0)); \
  }

#define DOConcat3InputAxis2 \
  DOConcat2InputAxis2; \
  const int h_end2 = h_end1 + input_shape2.z; \
  if (h_end1 <= h_idx && h_idx < h_end2) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input2, SAMPLER, \
        (int2)(src_x, n_idx * input_shape2.z + h_idx - h_end1)); \
  }

#define DOConcat4InputAxis2 \
  DOConcat3InputAxis2; \
  const int h_end3 = h_end2 + input_shape3.z; \
  if (h_end2 <= h_idx && h_idx < h_end3) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input3, SAMPLER, \
        (int2)(src_x, n_idx * input_shape3.z + h_idx - h_end2)); \
  }

#define DOConcat5InputAxis2 \
  DOConcat4InputAxis2; \
  const int h_end4 = h_end3 + input_shape4.z; \
  if (h_end3 <= h_idx && h_idx < h_end4) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input4, SAMPLER, \
        (int2)(src_x, n_idx * input_shape4.z + h_idx - h_end3)); \
  }

#define DOConcat6InputAxis2 \
  DOConcat5InputAxis2; \
  const int h_end5 = h_end4 + input_shape5.z; \
  if (h_end4 <= h_idx && h_idx < h_end5) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input5, SAMPLER, \
        (int2)(src_x, n_idx * input_shape5.z + h_idx - h_end4)); \
  }


// axis = 3 (W): the inputs are stacked along input_shape.w.
// w_endK is the running sum W0 + ... + WK. The input whose range contains
// w_idx supplies `result`; the source row is always nh_idx and the source
// column is c_blk_idx * WK plus w_idx rebased to that input's first column.
#define DOConcat2InputAxis3 \
  const int w_end0 = input_shape0.w;          /* W0 */ \
  const int w_end1 = w_end0 + input_shape1.w; /* W0 + W1 */ \
  if (w_idx < w_end0) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, SAMPLER, \
        (int2)(c_blk_idx * input_shape0.w + w_idx, nh_idx)); \
  } else if (w_idx < w_end1) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, SAMPLER, \
        (int2)(c_blk_idx * input_shape1.w + w_idx - w_end0, nh_idx)); \
  }

#define DOConcat3InputAxis3 \
  DOConcat2InputAxis3; \
  const int w_end2 = w_end1 + input_shape2.w; \
  if (w_end1 <= w_idx && w_idx < w_end2) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input2, SAMPLER, \
        (int2)(c_blk_idx * input_shape2.w + w_idx - w_end1, nh_idx)); \
  }

#define DOConcat4InputAxis3 \
  DOConcat3InputAxis3; \
  const int w_end3 = w_end2 + input_shape3.w; \
  if (w_end2 <= w_idx && w_idx < w_end3) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input3, SAMPLER, \
        (int2)(c_blk_idx * input_shape3.w + w_idx - w_end2, nh_idx)); \
  }

#define DOConcat5InputAxis3 \
  DOConcat4InputAxis3; \
  const int w_end4 = w_end3 + input_shape4.w; \
  if (w_end3 <= w_idx && w_idx < w_end4) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input4, SAMPLER, \
        (int2)(c_blk_idx * input_shape4.w + w_idx - w_end3, nh_idx)); \
  }

#define DOConcat6InputAxis3 \
  DOConcat5InputAxis3; \
  const int w_end5 = w_end4 + input_shape5.w; \
  if (w_end4 <= w_idx && w_idx < w_end5) { \
    result = READ_IMG_TYPE(CL_DTYPE_CHAR, input5, SAMPLER, \
        (int2)(c_blk_idx * input_shape5.w + w_idx - w_end4, nh_idx)); \
  }


/* Epilogue of the generated concat kernels: stores `result` into `output` at
 * column c_blk_idx * output_shape.w + w_idx, row nh_idx. The terminating ';'
 * is part of the expansion. */
#define WRITE_IMG_DATA \
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, \
                 (int2)(w_idx + c_blk_idx * output_shape.w, nh_idx), result);

/* Defines the two-input kernel Concat<InputTag><AxisTag>, e.g.
 * CONCAT2(2Input, Axis0) produces Concat2InputAxis0. The kernel body is
 * CHECK_IDX, then the matching DOConcat<InputTag><AxisTag> macro, then
 * WRITE_IMG_DATA. */
#define CONCAT2(InputTag, AxisTag) \
  __kernel void Concat##InputTag##AxisTag( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##InputTag##AxisTag \
    WRITE_IMG_DATA \
  }

/* Defines the three-input kernel Concat<InputTag><AxisTag>, e.g.
 * CONCAT3(3Input, Axis0) produces Concat3InputAxis0. The kernel body is
 * CHECK_IDX, then the matching DOConcat<InputTag><AxisTag> macro, then
 * WRITE_IMG_DATA. */
#define CONCAT3(InputTag, AxisTag) \
  __kernel void Concat##InputTag##AxisTag( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __read_only image2d_t input2, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 input_shape2, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##InputTag##AxisTag \
    WRITE_IMG_DATA \
  }

/* Defines the four-input kernel Concat<InputTag><AxisTag>, e.g.
 * CONCAT4(4Input, Axis0) produces Concat4InputAxis0. The kernel body is
 * CHECK_IDX, then the matching DOConcat<InputTag><AxisTag> macro, then
 * WRITE_IMG_DATA. */
#define CONCAT4(InputTag, AxisTag) \
  __kernel void Concat##InputTag##AxisTag( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __read_only image2d_t input2, \
      __read_only image2d_t input3, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 input_shape2, \
      int4 input_shape3, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##InputTag##AxisTag \
    WRITE_IMG_DATA \
  }

/* Defines the five-input kernel Concat<InputTag><AxisTag>, e.g.
 * CONCAT5(5Input, Axis0) produces Concat5InputAxis0. The kernel body is
 * CHECK_IDX, then the matching DOConcat<InputTag><AxisTag> macro, then
 * WRITE_IMG_DATA. */
#define CONCAT5(InputTag, AxisTag) \
  __kernel void Concat##InputTag##AxisTag( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __read_only image2d_t input2, \
      __read_only image2d_t input3, \
      __read_only image2d_t input4, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 input_shape2, \
      int4 input_shape3, \
      int4 input_shape4, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##InputTag##AxisTag \
    WRITE_IMG_DATA \
  }

/* Defines the six-input kernel Concat<InputTag><AxisTag>, e.g.
 * CONCAT6(6Input, Axis0) produces Concat6InputAxis0. The kernel body is
 * CHECK_IDX, then the matching DOConcat<InputTag><AxisTag> macro, then
 * WRITE_IMG_DATA. */
#define CONCAT6(InputTag, AxisTag) \
  __kernel void Concat##InputTag##AxisTag( \
      __read_only image2d_t input0, \
      __read_only image2d_t input1, \
      __read_only image2d_t input2, \
      __read_only image2d_t input3, \
      __read_only image2d_t input4, \
      __read_only image2d_t input5, \
      __write_only image2d_t output, \
      int4 input_shape0, \
      int4 input_shape1, \
      int4 input_shape2, \
      int4 input_shape3, \
      int4 input_shape4, \
      int4 input_shape5, \
      int4 output_shape) { \
    CHECK_IDX \
    DOConcat##InputTag##AxisTag \
    WRITE_IMG_DATA \
  }

// Kernel instantiations: one kernel per (input count, axis) pair, named
// Concat<K>InputAxis<A>. Grouped by input count; axes 0..3 within each group.
// 2 inputs
CONCAT2(2Input, Axis0)
CONCAT2(2Input, Axis1)
CONCAT2(2Input, Axis2)
CONCAT2(2Input, Axis3)
// 3 inputs
CONCAT3(3Input, Axis0)
CONCAT3(3Input, Axis1)
CONCAT3(3Input, Axis2)
CONCAT3(3Input, Axis3)
// 4 inputs
CONCAT4(4Input, Axis0)
CONCAT4(4Input, Axis1)
CONCAT4(4Input, Axis2)
CONCAT4(4Input, Axis3)
// 5 inputs
CONCAT5(5Input, Axis0)
CONCAT5(5Input, Axis1)
CONCAT5(5Input, Axis2)
CONCAT5(5Input, Axis3)
// 6 inputs
CONCAT6(6Input, Axis0)
CONCAT6(6Input, Axis1)
CONCAT6(6Input, Axis2)
CONCAT6(6Input, Axis3)
/********************************************************
* Case where the C (channel) axis of every input is aligned: End
********************************************************/


// deprecated
__kernel void concatByCWith2Inputs(
__write_only image2d_t output_image,
Expand Down Expand Up @@ -222,57 +533,6 @@ __kernel void concatByCWith4Inputs(
}


// deprecated
// Copies one texel per work-item from input_image to output_image, shifting
// the row by out_H_Start.
//   out_W       - multiplier applied to global id 0 when forming the column
//   out_H_Start - row offset added to the source row to get the output row
// No bounds check is performed on the global ids.
__kernel void concatByH(__read_only image2d_t input_image,
                        __write_only image2d_t output_image,
                        __private const int out_W,
                        __private const int out_H_Start) {
  const int c_idx = get_global_id(0);
  const int w_idx = get_global_id(1);
  const int nh_idx = get_global_id(2);

  // Source texel; the destination keeps the column and offsets the row.
  const int2 src_pos = (int2)(c_idx * out_W + w_idx, nh_idx);
  const int2 dst_pos = (int2)(src_pos.x, out_H_Start + src_pos.y);

  const CL_DTYPE4 texel =
      READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, src_pos);
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, dst_pos, texel);
}


// deprecated
// Copies one texel per work-item from input_image to output_image, keeping
// the row and shifting the column.
//   in_W      - multiplier applied to global id 0 when forming the source column
//   pre_Width - constant column offset added on the output side
//   out_Width - per-global-id-0 column stride added on the output side
// NOTE(review): the output column is src_x + pre_Width + out_Width * c_idx,
// i.e. it already contains c_idx * in_W through src_x; presumably the host
// passes out_Width with that in mind -- confirm against the caller.
// No bounds check is performed on the global ids.
__kernel void concatByW(__read_only image2d_t input_image,
                        __write_only image2d_t output_image,
                        __private const int in_W,
                        __private const int pre_Width,
                        __private const int out_Width) {
  const int c_idx = get_global_id(0);
  const int w_idx = get_global_id(1);
  const int nh_idx = get_global_id(2);

  const int2 src_pos = (int2)(c_idx * in_W + w_idx, nh_idx);
  const int2 dst_pos =
      (int2)(src_pos.x + pre_Width + out_Width * c_idx, src_pos.y);

  const CL_DTYPE4 texel =
      READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, src_pos);
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, dst_pos, texel);
}


__kernel void concat2(__read_only image2d_t input0,
__read_only image2d_t input1,
__write_only image2d_t output,
Expand Down
Loading

0 comments on commit 6478459

Please sign in to comment.