From 4c103e93b86480277133e25966b8f0ac977d520d Mon Sep 17 00:00:00 2001
From: Yang Chen
Date: Sat, 21 Oct 2023 01:10:44 -0700
Subject: [PATCH] Fixed kernel launch config for permute021 (#957)

Summary: make sure grid_z is less than 65535.

Differential Revision: D50517471
---
 .../common/tensor/permute021_common.py | 96 +++++++++++++++----
 tests/unittest/ops/test_permute021.py  |  3 +
 2 files changed, 83 insertions(+), 16 deletions(-)

diff --git a/python/aitemplate/backend/common/tensor/permute021_common.py b/python/aitemplate/backend/common/tensor/permute021_common.py
index 86d9b1578..426693714 100644
--- a/python/aitemplate/backend/common/tensor/permute021_common.py
+++ b/python/aitemplate/backend/common/tensor/permute021_common.py
@@ -84,7 +84,36 @@
 {{tensor_accessor_libs}}
 
-template <typename T>
+// blockIdx.x -> ni
+// blockIdx.y -> hwi
+// blockIdx.z -> ci
+__device__ __forceinline__ void block_fn_nhc(int32_t& ni, int32_t& hwi, int32_t& ci) {
+  ni = blockIdx.x;
+  hwi = blockIdx.y;
+  ci = blockIdx.z;
+}
+
+// blockIdx.x -> ni
+// blockIdx.y -> ci
+// blockIdx.z -> hwi
+__device__ __forceinline__ void block_fn_nch(int32_t& ni, int32_t& hwi, int32_t& ci) {
+  ni = blockIdx.x;
+  ci = blockIdx.y;
+  hwi = blockIdx.z;
+}
+
+// blockIdx.x -> ci
+// blockIdx.y -> hwi
+// blockIdx.z -> ni
+__device__ __forceinline__ void block_fn_chn(int32_t& ni, int32_t& hwi, int32_t& ci) {
+  ci = blockIdx.x;
+  hwi = blockIdx.y;
+  ni = blockIdx.z;
+}
+
+using BlockFunc = void (*)(int32_t&, int32_t&, int32_t&);
+
+template <typename T, BlockFunc BLOCK_FN>
 __global__ void permute021_kernel(T *output,
                                   const T *input,
                                   const int64_t n,
@@ -101,9 +130,11 @@
   const int32_t tid = threadIdx.y * blockDim.x + threadIdx.x;
   const int32_t wid = tid / TILE_SIZE;
   const int32_t lid = tid % TILE_SIZE;
-  const int32_t ni = blockIdx.z;
-  const int32_t hwi0 = blockIdx.y * TILE_SIZE;
-  const int32_t ci0 = blockIdx.x * TILE_SIZE;
+  int32_t ni_tmp, hwi_tmp, ci_tmp;
+  BLOCK_FN(ni_tmp, hwi_tmp, ci_tmp);
+  const int32_t ni = ni_tmp;
+  const int32_t hwi0 = hwi_tmp * TILE_SIZE;
+  const int32_t ci0 = ci_tmp * TILE_SIZE;
 
   size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0;
@@ -172,21 +203,54 @@
   const int32_t x_dim1 = x_dims[rank-2];
   const int32_t x_dim2 = x_dims[rank-1];
 
-  const int64_t n = x_dim0;
+#define THROW_INVALID_LAUNCH_CONFIG                 \
+  throw std::runtime_error(                         \
+      std::string("invalid cuda launch config: ") + \
+      std::to_string(grid_c) + ", " +               \
+      std::to_string(grid_hw) + ", " +              \
+      std::to_string(grid_n));
+
+  const int32_t n = static_cast<int32_t>(x_dim0);
   const int32_t h = 1;
   const int32_t w = x_dim1;
   const int32_t c = x_dim2;
-  dim3 grid((c + TILE_SIZE - 1) / TILE_SIZE, (h * w + TILE_SIZE - 1) / TILE_SIZE, n);
-  dim3 block(TILE_SIZE, TILE_SIZE / CH_K);
-  permute021_kernel<{{lib_dtype}}><<<grid, block, 0, stream>>>(
-    static_cast<{{lib_dtype}}*>(out_ptr),
-    static_cast<const {{lib_dtype}}*>(in_ptr),
-    n,
-    h,
-    w,
-    c,
-    input_accessor
-  );
+  const int32_t grid_c = (c + TILE_SIZE - 1) / TILE_SIZE;
+  const int32_t grid_hw = (h * w + TILE_SIZE - 1) / TILE_SIZE;
+  const int32_t grid_n = n;
+  constexpr int32_t max_grid_z = 65535;
+  constexpr int32_t max_grid_x = 2147483647;
+  if (grid_c > max_grid_x || grid_hw > max_grid_x || grid_n > max_grid_x) {
+    THROW_INVALID_LAUNCH_CONFIG
+  }
+  if ((grid_c <= max_grid_z && grid_hw <= max_grid_z && grid_n <= max_grid_z) ||
+      (grid_c > max_grid_z && grid_hw <= max_grid_z && grid_n <= max_grid_z)) {
+    dim3 grid(grid_c, grid_hw, grid_n);
+    dim3 block(TILE_SIZE, TILE_SIZE / CH_K);
+    permute021_kernel<{{lib_dtype}}, block_fn_chn><<<grid, block, 0, stream>>>(
+        static_cast<{{lib_dtype}}*>(out_ptr),
+        static_cast<const {{lib_dtype}}*>(in_ptr),
+        n, h, w, c, input_accessor
+    );
+  } else if (grid_n > max_grid_z && grid_hw <= max_grid_z && grid_c <= max_grid_z) {
+    dim3 grid(grid_n, grid_c, grid_hw);
+    dim3 block(TILE_SIZE, TILE_SIZE / CH_K);
+    permute021_kernel<{{lib_dtype}}, block_fn_nch><<<grid, block, 0, stream>>>(
+        static_cast<{{lib_dtype}}*>(out_ptr),
+        static_cast<const {{lib_dtype}}*>(in_ptr),
+        n, h, w, c, input_accessor
+    );
+  } else if (grid_hw > max_grid_z && grid_n <= max_grid_z && grid_c <= max_grid_z) {
+    dim3 grid(grid_n, grid_hw, grid_c);
+    dim3 block(TILE_SIZE, TILE_SIZE / CH_K);
+    permute021_kernel<{{lib_dtype}}, block_fn_nhc><<<grid, block, 0, stream>>>(
+        static_cast<{{lib_dtype}}*>(out_ptr),
+        static_cast<const {{lib_dtype}}*>(in_ptr),
+        n, h, w, c, input_accessor
+    );
+  } else {
+    THROW_INVALID_LAUNCH_CONFIG
+  }
 }
 
 } // namespace

diff --git a/tests/unittest/ops/test_permute021.py b/tests/unittest/ops/test_permute021.py
index df7b2be39..76e49138a 100644
--- a/tests/unittest/ops/test_permute021.py
+++ b/tests/unittest/ops/test_permute021.py
@@ -69,6 +69,9 @@ def _test_permute_021(
         param(3, (2, 3, 4, 384, 262), (0, 1, 2, 4, 3)),
         param(4, (IntVar([2, 3]), 384, 262), (0, 2, 1)),
         param(5, (IntVar([2, 3, 4]), 5, 384, 262), (0, 1, 3, 2)),
+        param(6, (409600, 12, 16), (0, 2, 1)),
+        param(7, (12, 409600, 16), (0, 2, 1)),
+        param(8, (12, 16, 409600), (0, 2, 1)),
     ]
 )
 def test_permute021_fp16(self, id, input_shape, dims):
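
[Editor note, not part of the patch] The dispatch above exists because CUDA caps gridDim.y and gridDim.z at 65535, while gridDim.x may be as large as 2^31 - 1. The pre-patch code always placed the batch dimension n on blockIdx.z, so a shape like (409600, 12, 16) from the new test cases could not launch at all. A minimal standalone repro of that failure mode, assuming an installed CUDA toolkit (the kernel name and launch shapes are illustrative, not from the patch):

#include <cstdio>
#include <cuda_runtime.h>

// Trivial kernel: only the launch configuration matters here.
__global__ void noop() {}

int main() {
  // 409600 blocks on gridDim.z (the pre-patch layout) exceeds the
  // 65535 cap on the y/z grid dimensions, so the launch is rejected.
  noop<<<dim3(1, 1, 409600), 32>>>();
  printf("n on z: %s\n", cudaGetErrorString(cudaGetLastError()));

  // The same block count on gridDim.x is accepted (x allows up to
  // 2^31 - 1 blocks), which is what routing n to blockIdx.x enables.
  noop<<<dim3(409600, 1, 1), 32>>>();
  printf("n on x: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaDeviceSynchronize();
  return 0;
}

The first launch reports a configuration error and the second succeeds; that asymmetry is exactly what the grid_c/grid_hw/grid_n dispatch in the launcher works around.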
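[Editor note, not part of the patch] The kernel-side change relies on passing a __device__ function pointer as a non-type template parameter (BLOCK_FN), so each instantiation inlines its blockIdx -> (ni, hwi, ci) mapping with no runtime indirection. A stripped-down sketch of the same pattern, assuming nvcc (names here are illustrative, not from the patch):

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

using BlockFunc = void (*)(int32_t&, int32_t&, int32_t&);

// Mirrors block_fn_nhc from the patch: identity mapping of block indices.
__device__ __forceinline__ void fn_xyz(int32_t& ni, int32_t& hwi, int32_t& ci) {
  ni = blockIdx.x;
  hwi = blockIdx.y;
  ci = blockIdx.z;
}

// BLOCK_FN is fixed per instantiation, so the call below compiles to
// direct reads of blockIdx rather than an indirect function call.
template <BlockFunc BLOCK_FN>
__global__ void show_mapping() {
  int32_t ni, hwi, ci;
  BLOCK_FN(ni, hwi, ci);
  if (threadIdx.x == 0 && ni == 1 && hwi == 0 && ci == 0) {
    printf("block (1,0,0) maps to ni=%d hwi=%d ci=%d\n", ni, hwi, ci);
  }
}

int main() {
  show_mapping<fn_xyz><<<dim3(2, 1, 1), 32>>>();
  cudaDeviceSynchronize();
  return 0;
}

Swapping fn_xyz for another mapping changes which tensor dimension rides on which grid axis without touching the kernel body, which is how the launcher picks a legal configuration per shape.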