From fdc052e84b7e7825f5e367a36944cb306621e45f Mon Sep 17 00:00:00 2001 From: BinZheng <38182684+Wickyzheng@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:17:34 +0800 Subject: [PATCH] [Enhance] Optimize the performance of ms_deform_attn for MLU device (#2510) * ms_opt * ms_opt * ms_opt * ms_opt * ms_opt * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization --- .../common/mlu/ms_deform_attn_mlu_kernel.mlu | 1105 +++++++++++++++-- .../csrc/pytorch/mlu/ms_deform_attn_mlu.cpp | 143 ++- 2 files changed, 1137 insertions(+), 111 deletions(-) diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu index 7899e52cd3..6aab1dae21 100644 --- a/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu @@ -42,15 +42,16 @@ ****************************************************************************************/ #define TWELVE_SPLIT 12 -#define ALIGN_NUM 64 +#define ALIGN_NUM 32 #define ALIGN_NUM_FOR_REDUCE 32 +#define LEN_FLOAT sizeof(float) __nram__ char nram_buffer[MAX_NRAM_SIZE]; template __mlu_func__ void loadNeighborPointsData( const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram, - T *data_value_p3_nram, T *data_value_p4_nram, const size_t deal_num, + T *data_value_p3_nram, T *data_value_p4_nram, const size_t &deal_num, const int32_t &width, const int32_t &height, const int32_t &num_heads, const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) { const int32_t w_low = floorf(x); @@ -100,11 +101,11 @@ __mlu_func__ void loadNeighborPointsData( } template -__mlu_func__ void bilinearInterpolation( +__mlu_func__ void computeMsDeformAttn( T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram, T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b, - const size_t deal_num, const int32_t &width, const int32_t &height, - const T &x, const T &y) { + T *data_col_nram, const T &weight, const size_t &deal_num, + const int32_t &width, const int32_t &height, const T &x, const T &y) { const int32_t w_low = floorf(x); const int32_t h_low = floorf(y); const int32_t w_high = w_low + 1; @@ -156,10 +157,15 @@ __mlu_func__ void bilinearInterpolation( __bang_add((T *)sample_point_value, (T *)sample_point_value, (T *)auxiliary_b, deal_num); } + + __bang_mul_scalar((T *)sample_point_value, (T *)sample_point_value, (T)weight, + deal_num); + __bang_add((T *)data_col_nram, (T *)data_col_nram, (T *)sample_point_value, + deal_num); } template -__mlu_global__ void MLUKernelMsDeformAttnForward( +__mlu_global__ void MLUKernelMsDeformAttnForwardDefault( const char *data_value_gdram, const char *data_spatial_shapes_gdram, const char *data_level_start_index_gdram, const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, @@ -346,7 +352,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( // compute if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { - bilinearInterpolation( + computeMsDeformAttn( (T *)(ping_data_value_p1_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), (T *)(ping_data_value_p2_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), (T *)(ping_data_value_p3_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), (T *)(ping_data_value_p4_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), - (T *)auxiliary_a, (T *)auxiliary_b, span_num_deal, spatial_w, - spatial_h, x, y); - __bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight, - span_num_deal); - 
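Note on the hunks above: renaming bilinearInterpolation to computeMsDeformAttn also folds the per-point __bang_mul_scalar/__bang_add pair into the helper, so the bilinear blend and the attention-weighted accumulation into data_col_nram now happen in one place and the removed call-site arithmetic below becomes redundant. For readers, a minimal scalar reference of the per-channel math (an illustrative sketch; the helper name and plain loop are hypothetical, not the vectorized BANG implementation):

#include <math.h>

// For one sampling point (x, y) whose four neighbor rows v1..v4 are already
// gathered: blend bilinearly, scale by the attention weight, accumulate.
static void compute_point_ref(const float *v1, const float *v2,
                              const float *v3, const float *v4, float x,
                              float y, float weight, float *data_col,
                              int channels) {
  const float lw = x - floorf(x), lh = y - floorf(y);  // fractional parts
  const float hw = 1.f - lw, hh = 1.f - lh;
  const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  for (int c = 0; c < channels; ++c) {
    const float sample = w1 * v1[c] + w2 * v2[c] + w3 * v3[c] + w4 * v4[c];
    data_col[c] += weight * sample;  // fused scale-and-accumulate
  }
}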
__bang_add((T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - (T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - (T *)auxiliary_a, span_num_deal); + (T *)auxiliary_a, (T *)auxiliary_b, + (T *)(ping_data_col_nram + + data_col_ping_pong_idx * ping_pong_gap), + weight, span_num_deal, spatial_w, spatial_h, x, y); } spatial_w = spatial_w_next_point; @@ -500,7 +501,459 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( // compute if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { - bilinearInterpolation( + computeMsDeformAttn( + (T *)(ping_data_value_p1_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p2_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p3_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p4_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)auxiliary_a, (T *)auxiliary_b, + (T *)(ping_data_col_nram + + data_col_ping_pong_idx * ping_pong_gap), + weight, channels_align_rem, spatial_w, spatial_h, x, y); + } + + spatial_w = spatial_w_next_point; + spatial_h = spatial_h_next_point; + weight = weight_next_point; + x = x_next_point; + y = y_next_point; + __asm__ volatile("sync;"); + } + } + // store + __memcpy_async( + data_col_gdram_start + channels_seg_num * span_num_deal * sizeof(T), + ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap, + channels_rem * sizeof(T), NRAM2GDRAM); + data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2; + } + } + __asm__ volatile("sync;"); + return; +} + +template +__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( + const char *data_value_gdram, const char *data_spatial_shapes_gdram, + const char *data_level_start_index_gdram, + const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char *data_col_gdram) { + if (coreId == 0x80) { + return; + } + + const size_t spatial_size = + PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE); + const size_t level_start_index_size = + PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE); + size_t sampling_loc_size = + PAD_UP(num_levels * num_points * 2 * sizeof(T), NFU_ALIGN_SIZE); + size_t attn_weight_size = + PAD_UP(num_levels * num_points * sizeof(T), NFU_ALIGN_SIZE); + size_t span_num_deal = + PAD_DOWN((MAX_NRAM_SIZE - spatial_size - level_start_index_size - + sampling_loc_size - attn_weight_size) / + TWELVE_SPLIT / sizeof(T), + NFU_ALIGN_SIZE); + const int32_t channels_seg_num = channels / span_num_deal; + const size_t channels_rem = channels % span_num_deal; + int32_t load_loc_weight_idx = 0; + int32_t load_loc_weight_seg = 1; + if (channels_seg_num == 0) { + span_num_deal = PAD_UP(channels, NFU_ALIGN_SIZE); + attn_weight_size = + PAD_DOWN((MAX_NRAM_SIZE - spatial_size - level_start_index_size - + TWELVE_SPLIT * span_num_deal * sizeof(T)) / + 3, + num_levels * num_points * sizeof(T)); + attn_weight_size = PAD_DOWN(attn_weight_size, NFU_ALIGN_SIZE); + sampling_loc_size = attn_weight_size * 2; + load_loc_weight_seg = + attn_weight_size / (num_levels * num_points * sizeof(T)); + } + +#if __BANG_ARCH__ < 322 + const size_t align_num = NFU_ALIGN_SIZE; + const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num); +#endif + char *data_spatial_shapes_nram = 
nram_buffer; + char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size; + char *data_sampling_loc_nram = + data_level_start_index_nram + level_start_index_size; + char *data_attn_weight_nram = data_sampling_loc_nram + sampling_loc_size; + char *ping_data_value_p1_nram = data_attn_weight_nram + attn_weight_size; + char *ping_data_value_p2_nram = + ping_data_value_p1_nram + span_num_deal * sizeof(T); + char *ping_data_value_p3_nram = + ping_data_value_p2_nram + span_num_deal * sizeof(T); + char *ping_data_value_p4_nram = + ping_data_value_p3_nram + span_num_deal * sizeof(T); + char *ping_data_col_nram = + ping_data_value_p4_nram + span_num_deal * sizeof(T); + char *pong_data_value_p1_nram = + ping_data_col_nram + span_num_deal * sizeof(T); + char *pong_data_value_p2_nram = + pong_data_value_p1_nram + span_num_deal * sizeof(T); + char *pong_data_value_p3_nram = + pong_data_value_p2_nram + span_num_deal * sizeof(T); + char *pong_data_value_p4_nram = + pong_data_value_p3_nram + span_num_deal * sizeof(T); + char *pong_data_col_nram = + pong_data_value_p4_nram + span_num_deal * sizeof(T); + char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T); + char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T); + const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T); + size_t data_col_ping_pong_idx = 0; + + const int32_t block_num_rem = + (batch_size * num_queries * num_heads) % taskDim; + const int32_t block_num_per_core = + taskId < block_num_rem + ? (batch_size * num_queries * num_heads) / taskDim + 1 + : (batch_size * num_queries * num_heads) / taskDim; + const int32_t idx_start = taskId < block_num_rem + ? taskId * block_num_per_core + : taskId * block_num_per_core + block_num_rem; + + __memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram, + num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); + __memcpy_async(data_level_start_index_nram, data_level_start_index_gdram, + num_levels * sizeof(int32_t), GDRAM2NRAM); + + for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core; + ++cur_idx) { + // cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads + + // head_idx + const int32_t head_idx = cur_idx % num_heads; + const int32_t batch_idx = (cur_idx / num_heads) / num_queries; + + const char *data_value_gdram_start = + data_value_gdram + + batch_idx * num_keys * num_heads * channels * sizeof(T); + char *data_col_gdram_start = + data_col_gdram + cur_idx * channels * sizeof(T); + + if (load_loc_weight_seg == 1 || + (load_loc_weight_idx % load_loc_weight_seg) == 0) { + const char *data_sampling_loc_gdram_start = + data_sampling_loc_gdram + + cur_idx * num_levels * num_points * 2 * sizeof(T); + const char *data_attn_weight_gdram_start = + data_attn_weight_gdram + + cur_idx * num_levels * num_points * sizeof(T); + const int32_t load_loc_weight_size = + (block_num_per_core - load_loc_weight_idx) < load_loc_weight_seg + ? 
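The NRAM layout set up above packs five ping buffers (data_value p1..p4 plus the output column) back to back, with the matching pong set ping_pong_gap = 5 * span_num_deal * sizeof(T) bytes later, so toggling the parity of the linear point index (level_idx * num_points + point_idx) flips between the two sets. A small addressing sketch under those assumptions (helper name illustrative):

#include <stddef.h>

// Select one of the five per-point buffers in either the ping or pong set.
static char *select_buffer(char *ping_base, size_t span_bytes,
                           int which_of_five, int point_linear_idx) {
  const size_t ping_pong_gap = 5 * span_bytes;  // p1..p4 + data_col
  return ping_base + which_of_five * span_bytes +
         (point_linear_idx % 2) * ping_pong_gap;
}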
block_num_per_core - load_loc_weight_idx + : load_loc_weight_seg; + __memcpy_async( + data_sampling_loc_nram, data_sampling_loc_gdram_start, + load_loc_weight_size * num_levels * num_points * 2 * sizeof(T), + GDRAM2NRAM); + __memcpy_async(data_attn_weight_nram, data_attn_weight_gdram_start, + load_loc_weight_size * num_levels * num_points * sizeof(T), + GDRAM2NRAM); + __asm__ volatile("sync;"); + } + const int32_t load_loc_weight_offset = + (load_loc_weight_idx % load_loc_weight_seg) * num_levels * num_points; + + for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) { + __bang_write_value( + (T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap), + span_num_deal, (T)0); + // load data + // level_idx = 0, point_idx = 0 + int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0]; + int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1]; + const char *data_value_ptr = + data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T); + T loc_w = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2]; + T loc_h = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2 + 1]; + T weight = ((T *)data_attn_weight_nram)[load_loc_weight_offset]; + T x = loc_w * spatial_w - 0.5; + T y = loc_h * spatial_h - 0.5; + if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { + loadNeighborPointsData( + (T *)data_value_ptr, (T *)ping_data_value_p1_nram, + (T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram, + (T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h, + num_heads, channels, x, y, head_idx); + } + T spatial_h_next_point = 0; + T spatial_w_next_point = 0; + T weight_next_point = 0; + T x_next_point = 0; + T y_next_point = 0; + __asm__ volatile("sync;"); + + for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) { + for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) { + // load data + if (point_idx == num_points - 1 && level_idx == num_levels - 1) { + // last point no need to load data, continue to compute + } else if (point_idx == num_points - 1) { + const int32_t level_start_id = + ((int32_t *)data_level_start_index_nram)[level_idx + 1]; + const int32_t spatial_h_ptr = (level_idx + 1) << 1; + spatial_h_next_point = + ((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr]; + spatial_w_next_point = + ((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr + 1]; + data_value_ptr = data_value_gdram_start + + (level_start_id * num_heads * channels + + c_seg_idx * span_num_deal) * + sizeof(T); + loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset + + level_idx * num_points + + point_idx + 1) * + 2]; + loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset + + level_idx * num_points + + point_idx + 1) * + 2 + + 1]; + weight_next_point = + ((T *)data_attn_weight_nram)[load_loc_weight_offset + + level_idx * num_points + + point_idx + 1]; + x_next_point = loc_w * spatial_w_next_point - 0.5; + y_next_point = loc_h * spatial_h_next_point - 0.5; + if (y_next_point > -1 && x_next_point > -1 && + y_next_point < spatial_h_next_point && + x_next_point < spatial_w_next_point) { + loadNeighborPointsData( + (T *)data_value_ptr, + (T *)(ping_data_value_p1_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p2_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p3_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p4_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + 
ping_pong_gap), + span_num_deal, spatial_w_next_point, spatial_h_next_point, + num_heads, channels, x_next_point, y_next_point, head_idx); + } + } else { + spatial_h_next_point = spatial_h; + spatial_w_next_point = spatial_w; + loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset + + level_idx * num_points + + point_idx + 1) * + 2]; + loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset + + level_idx * num_points + + point_idx + 1) * + 2 + + 1]; + weight_next_point = + ((T *)data_attn_weight_nram)[load_loc_weight_offset + + level_idx * num_points + + point_idx + 1]; + x_next_point = loc_w * spatial_w - 0.5; + y_next_point = loc_h * spatial_h - 0.5; + if (y_next_point > -1 && x_next_point > -1 && + y_next_point < spatial_h && x_next_point < spatial_w) { + loadNeighborPointsData( + (T *)data_value_ptr, + (T *)(ping_data_value_p1_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p2_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p3_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p4_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + span_num_deal, spatial_w, spatial_h, num_heads, channels, + x_next_point, y_next_point, head_idx); + } + } + + // compute + if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { + computeMsDeformAttn( + (T *)(ping_data_value_p1_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p2_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p3_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p4_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)auxiliary_a, (T *)auxiliary_b, + (T *)(ping_data_col_nram + + data_col_ping_pong_idx * ping_pong_gap), + weight, span_num_deal, spatial_w, spatial_h, x, y); + } + + spatial_w = spatial_w_next_point; + spatial_h = spatial_h_next_point; + weight = weight_next_point; + x = x_next_point; + y = y_next_point; + __asm__ volatile("sync;"); + } + } + // store + __memcpy_async( + data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T), + ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap, + span_num_deal * sizeof(T), NRAM2GDRAM); + data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2; + } + + if (channels_rem > 0) { +#if __BANG_ARCH__ >= 322 + __bang_write_value( + (T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap), + channels_rem, (T)0); +#else + __bang_write_value( + (T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap), + channels_align_rem, (T)0); +#endif + // load data + // level_idx = 0, point_idx = 0 + int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0]; + int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1]; + const char *data_value_ptr = + data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T); + T loc_w = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2]; + T loc_h = ((T *)data_sampling_loc_nram)[load_loc_weight_offset * 2 + 1]; + T weight = ((T *)data_attn_weight_nram)[load_loc_weight_offset]; + T x = loc_w * spatial_w - 0.5; + T y = loc_h * spatial_h - 0.5; + if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { + loadNeighborPointsData( + (T *)data_value_ptr, (T *)ping_data_value_p1_nram, + (T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram, + (T *)ping_data_value_p4_nram, 
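The level/point loop here is a two-stage software pipeline: while point k is being blended out of one NRAM set, the four neighbors of point k + 1 are prefetched into the other, and the trailing sync closes each overlap window. A host-side sketch of that schedule (stand-in names and data, not BANG code):

#include <cstdio>

int main() {
  const int total_points = 4;  // e.g. num_levels * num_points
  int buf[2] = {100, 0};       // stand-ins for the ping/pong sets;
                               // point 0 is preloaded before the loop
  for (int k = 0; k < total_points; ++k) {
    if (k + 1 < total_points)
      buf[(k + 1) % 2] = 100 + k + 1;  // "async" prefetch of point k + 1
    printf("compute point %d from buf[%d] = %d\n", k, k % 2, buf[k % 2]);
    // the kernel's __asm__ volatile("sync;") ends the overlap window here
  }
  return 0;
}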
channels_rem, spatial_w, spatial_h, + num_heads, channels, x, y, head_idx); + } + T spatial_h_next_point = 0; + T spatial_w_next_point = 0; + T weight_next_point = 0; + T x_next_point = 0; + T y_next_point = 0; + __asm__ volatile("sync;"); + + for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) { + for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) { + // load data + if (point_idx == num_points - 1 && level_idx == num_levels - 1) { + // last point no need to load data, continue to compute + } else if (point_idx == num_points - 1) { + const int32_t level_start_id = + ((int32_t *)data_level_start_index_nram)[level_idx + 1]; + const int32_t spatial_h_ptr = (level_idx + 1) << 1; + spatial_h_next_point = + ((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr]; + spatial_w_next_point = + ((int32_t *)data_spatial_shapes_nram)[spatial_h_ptr + 1]; + data_value_ptr = data_value_gdram_start + + (level_start_id * num_heads * channels + + channels_seg_num * span_num_deal) * + sizeof(T); + loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset + + level_idx * num_points + + point_idx + 1) * + 2]; + loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset + + level_idx * num_points + + point_idx + 1) * + 2 + + 1]; + weight_next_point = + ((T *)data_attn_weight_nram)[load_loc_weight_offset + + level_idx * num_points + + point_idx + 1]; + x_next_point = loc_w * spatial_w_next_point - 0.5; + y_next_point = loc_h * spatial_h_next_point - 0.5; + if (y_next_point > -1 && x_next_point > -1 && + y_next_point < spatial_h_next_point && + x_next_point < spatial_w_next_point) { + loadNeighborPointsData( + (T *)data_value_ptr, + (T *)(ping_data_value_p1_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p2_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p3_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p4_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + channels_rem, spatial_w_next_point, spatial_h_next_point, + num_heads, channels, x_next_point, y_next_point, head_idx); + } + } else { + spatial_w_next_point = spatial_w; + spatial_h_next_point = spatial_h; + loc_w = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset + + level_idx * num_points + + point_idx + 1) * + 2]; + loc_h = ((T *)data_sampling_loc_nram)[(load_loc_weight_offset + + level_idx * num_points + + point_idx + 1) * + 2 + + 1]; + weight_next_point = + ((T *)data_attn_weight_nram)[load_loc_weight_offset + + level_idx * num_points + + point_idx + 1]; + x_next_point = loc_w * spatial_w - 0.5; + y_next_point = loc_h * spatial_h - 0.5; + if (y_next_point > -1 && x_next_point > -1 && + y_next_point < spatial_h && x_next_point < spatial_w) { + loadNeighborPointsData( + (T *)data_value_ptr, + (T *)(ping_data_value_p1_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p2_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p3_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p4_nram + + ((level_idx * num_points + point_idx + 1) % 2) * + ping_pong_gap), + channels_rem, spatial_w, spatial_h, num_heads, channels, + x_next_point, y_next_point, head_idx); + } + } + + // compute + if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { +#if __BANG_ARCH__ >= 322 + computeMsDeformAttn( (T 
*)(ping_data_value_p1_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), @@ -513,15 +966,29 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( (T *)(ping_data_value_p4_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), - (T *)auxiliary_a, (T *)auxiliary_b, channels_align_rem, - spatial_w, spatial_h, x, y); - __bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight, - channels_align_rem); - __bang_add((T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - (T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - (T *)auxiliary_a, channels_align_rem); + (T *)auxiliary_a, (T *)auxiliary_b, + (T *)(ping_data_col_nram + + data_col_ping_pong_idx * ping_pong_gap), + weight, channels_rem, spatial_w, spatial_h, x, y); +#else + computeMsDeformAttn( + (T *)(ping_data_value_p1_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p2_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p3_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)(ping_data_value_p4_nram + + ((level_idx * num_points + point_idx) % 2) * + ping_pong_gap), + (T *)auxiliary_a, (T *)auxiliary_b, + (T *)(ping_data_col_nram + + data_col_ping_pong_idx * ping_pong_gap), + weight, channels_align_rem, spatial_w, spatial_h, x, y); +#endif } spatial_w = spatial_w_next_point; @@ -539,12 +1006,36 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( channels_rem * sizeof(T), NRAM2GDRAM); data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2; } + load_loc_weight_idx += 1; } __asm__ volatile("sync;"); return; } -template __mlu_global__ void MLUKernelMsDeformAttnForward( +template __mlu_global__ void MLUKernelMsDeformAttnForwardDefault( + const char *data_value_gdram, const char *data_spatial_shapes_gdram, + const char *data_level_start_index_gdram, + const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char *data_col_gdram); + +void KernelMsDeformAttnForwardDefault( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const char *data_value_gdram, + const char *data_spatial_shapes_gdram, + const char *data_level_start_index_gdram, + const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char *data_col_gdram) { + MLUKernelMsDeformAttnForwardDefault<<>>( + data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, + data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); +} + +template __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( const char *data_value_gdram, const char *data_spatial_shapes_gdram, const char *data_level_start_index_gdram, const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, @@ -552,7 +1043,7 @@ template __mlu_global__ void MLUKernelMsDeformAttnForward( const int32_t channels, const int32_t num_levels, const int32_t num_queries, const int32_t num_points, char *data_col_gdram); -void KernelMsDeformAttnForward( +void KernelMsDeformAttnForwardSmallChannel( 
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t d_type, const char *data_value_gdram, const char *data_spatial_shapes_gdram, @@ -561,7 +1052,7 @@ void KernelMsDeformAttnForward( const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, const int32_t channels, const int32_t num_levels, const int32_t num_queries, const int32_t num_points, char *data_col_gdram) { - MLUKernelMsDeformAttnForward<<>>( + MLUKernelMsDeformAttnForwardSmallChannel<<>>( data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); @@ -584,15 +1075,15 @@ void __mlu_func__ msDeformAttnCol2imBilinear( int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; __memcpy(grad_output_nram, data_value_ptr + offset1, deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num); - __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num); - __bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num); - __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num); + __bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num_real); + __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); + __bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num_real); + __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num_real); // for calc grad_attn_weight - __bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num); + __bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num_real); __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1), (T *)top_grad_temp, deal_num_real); } @@ -600,18 +1091,18 @@ void __mlu_func__ msDeformAttnCol2imBilinear( int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; __memcpy(grad_output_nram_temp, data_value_ptr + offset2, deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num); - __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num); - __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real); + __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num_real); + __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num_real); __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2, - deal_num); + deal_num_real); __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num); + deal_num_real); __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2), (T *)top_grad_temp, deal_num_real); } @@ -619,18 +1110,18 @@ void __mlu_func__ msDeformAttnCol2imBilinear( int32_t offset3 
= h_high_ptr_offset + w_low_ptr_offset + base_ptr; __memcpy(grad_output_nram_temp, data_value_ptr + offset3, deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num); - __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num); - __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num_real); + __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real); + __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num_real); // for calc grad_attn_weight __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3, - deal_num); + deal_num_real); __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num); + deal_num_real); __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3), (T *)top_grad_temp, deal_num_real); } @@ -638,63 +1129,61 @@ void __mlu_func__ msDeformAttnCol2imBilinear( int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; __memcpy(grad_output_nram_temp, data_value_ptr + offset4, deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num); - __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num); - __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real); + __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real); + __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num_real); // for calc grad_attn_weight __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4, - deal_num); + deal_num_real); __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num); + deal_num_real); __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4), (T *)top_grad_temp, deal_num_real); } - __bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num); + __bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num_real); #if __BANG_ARCH__ >= 322 recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); #else - const int32_t align_num_on_200 = NFU_ALIGN_SIZE / sizeof(float); + const int32_t align_num_on_200 = NFU_ALIGN_SIZE / LEN_FLOAT; recursiveSumPool(grad_output_nram, align_num_on_200, deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE); __bang_reduce_sum(grad_output_nram, grad_output_nram, - NFU_ALIGN_SIZE / sizeof(float)); + NFU_ALIGN_SIZE / LEN_FLOAT); #endif __bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight, (T *)grad_output_nram, 1); - __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num); - 
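For reference, the per-channel math behind the grad_h_weight / grad_w_weight / grad_attn_weight accumulations in this function, written out as a scalar sketch (hypothetical names; the forward result is out = aw * (w1*v1 + w2*v2 + w3*v3 + w4*v4) and g is the incoming top_grad element). The deal_num -> deal_num_real changes in this hunk trim the vector ops to the elements actually loaded instead of the full padded buffer.

#include <math.h>

struct PointGrads { float d_value[4], d_loc_w, d_loc_h, d_attn; };

static PointGrads bilinear_backward_ref(float v1, float v2, float v3,
                                        float v4, float x, float y,
                                        int width, int height, float g,
                                        float aw) {
  const float lw = x - floorf(x), lh = y - floorf(y);
  const float hw = 1.f - lw, hh = 1.f - lh;
  const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const float t = g * aw;  // top_grad_temp in the kernel
  PointGrads o;
  o.d_value[0] = t * w1; o.d_value[1] = t * w2;
  o.d_value[2] = t * w3; o.d_value[3] = t * w4;
  o.d_attn = g * (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  // d(out)/dx and d(out)/dy, then the chain rule through
  // x = loc_w * width - 0.5 and y = loc_h * height - 0.5:
  const float d_x = -hh * v1 + hh * v2 - lh * v3 + lh * v4;
  const float d_y = -hw * v1 - lw * v2 + hw * v3 + lw * v4;
  o.d_loc_w = t * d_x * width;
  o.d_loc_h = t * d_y * height;
  return o;
}

In the kernel the same quantities are computed as length-deal_num_real vectors, reduced with recursiveSumPool, and atomically added into grad_attn_weight and grad_sampling_loc.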
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num); + __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real); #if __BANG_ARCH__ >= 322 recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); #else recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE); - __bang_reduce_sum(grad_w_weight, grad_w_weight, - NFU_ALIGN_SIZE / sizeof(float)); + __bang_reduce_sum(grad_w_weight, grad_w_weight, NFU_ALIGN_SIZE / LEN_FLOAT); #endif __bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc), (T *)grad_w_weight, 1); - __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num); - __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num); + __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real); + __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real); #if __BANG_ARCH__ >= 322 recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); #else recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE); - __bang_reduce_sum(grad_h_weight, grad_h_weight, - NFU_ALIGN_SIZE / sizeof(float)); + __bang_reduce_sum(grad_h_weight, grad_h_weight, NFU_ALIGN_SIZE / LEN_FLOAT); #endif __bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1), (T *)grad_h_weight, 1); } -__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( +__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel( const float *data_value, const int32_t *spatial_shapes, const int32_t *data_level_start_index, const float *data_sampling_loc, const float *data_attn_weight, const float *grad_output, @@ -708,8 +1197,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( const int32_t split_num = 8; const int32_t spatial_shapes_size = 64; int32_t deal_num = PAD_DOWN( - (MAX_NRAM_SIZE - spatial_shapes_size) / split_num / sizeof(float), - ALIGN_NUM); + (MAX_NRAM_SIZE - spatial_shapes_size) / split_num / LEN_FLOAT, ALIGN_NUM); float *grad_output_nram = (float *)nram_buffer; float *grad_output_nram_temp = (float *)nram_buffer + deal_num; float *grad_weight = (float *)nram_buffer + 2 * deal_num; @@ -725,10 +1213,8 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( int32_t num_per_core = total_num / taskDim; int32_t num_rem = total_num % taskDim; num_per_core = num_per_core + int32_t(taskId < num_rem); - int32_t start_per_core = - num_rem > taskId - ? (taskId * num_per_core) - : ((num_per_core + 1) * num_rem + (taskId - num_rem) * num_per_core); + int32_t start_per_core = num_rem > taskId ? 
(taskId * num_per_core) + : (num_rem + taskId * num_per_core); int32_t end_per_core = start_per_core + num_per_core; const int32_t C_repeat = channels / deal_num; const int32_t C_tail = channels % deal_num; @@ -758,7 +1244,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( const int32_t grad_sampling_loc_out = num_loop * num_points * 2; for (int32_t p_col = 0; p_col < num_points; ++p_col) { __memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr, - 2 * sizeof(float), GDRAM2NRAM); + 2 * LEN_FLOAT, GDRAM2NRAM); const float loc_w = sampling_loc_nram[0]; const float loc_h = sampling_loc_nram[1]; const float weight = data_attn_weight[data_weight_ptr]; @@ -789,11 +1275,12 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) { base_ptr = m_col * channels + C_loop * deal_num; - __bang_write_zero(grad_weight, 3 * deal_num); - __bang_write_zero(grad_output_nram, deal_num); + __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM)); __memcpy(top_grad, grad_output + grad_output_offset + C_loop * deal_num, - deal_num * sizeof(float), GDRAM2NRAM); + deal_num * LEN_FLOAT, GDRAM2NRAM); msDeformAttnCol2imBilinear( top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low, h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset, @@ -806,10 +1293,12 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( } if (C_tail != 0) { base_ptr = m_col * channels + C_repeat * deal_num; - __bang_write_zero(grad_output_nram, 8 * deal_num); + __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM)); __memcpy(top_grad, grad_output + grad_output_offset + C_repeat * deal_num, - C_tail * sizeof(float), GDRAM2NRAM); + C_tail * LEN_FLOAT, GDRAM2NRAM); msDeformAttnCol2imBilinear( top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low, h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset, @@ -827,7 +1316,422 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( } } -__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( +template +void __mlu_func__ +loadData(const int32_t &h_low, const int32_t &w_low, const int32_t &h_high, + const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tr, + T *grad_output_nram_bl, T *grad_output_nram_br, + const T *data_value_ptr, const int32_t &width, const int32_t &height, + const int32_t &deal_num_real, const int32_t &h_low_ptr_offset, + const int32_t &w_low_ptr_offset, const int32_t &w_high_ptr_offset, + const int32_t &h_high_ptr_offset, const int32_t &base_ptr) { +#if __BANG_ARCH__ > 322 + if (h_low >= 0 && w_low >= 0) + + { + int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + __memcpy_async(grad_output_nram_tl, data_value_ptr + offset1, + deal_num_real * sizeof(T), GDRAM2NRAM); + } + if (h_low >= 0 && w_high <= width - 1) + + { + int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + __memcpy_async(grad_output_nram_tr, data_value_ptr + offset2, + deal_num_real * sizeof(T), GDRAM2NRAM); + } + if (h_high <= height - 1 && w_low >= 0) + + { + int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + __memcpy_async(grad_output_nram_bl, data_value_ptr + offset3, + deal_num_real * sizeof(T), GDRAM2NRAM); + } + if (h_high <= height - 1 && w_high <= width - 1) 
+ + { + int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + __memcpy_async(grad_output_nram_br, data_value_ptr + offset4, + deal_num_real * sizeof(T), GDRAM2NRAM); + } + __sync_io(); +#endif +} + +template +void __mlu_func__ computeData( + const int32_t &h_low, const int32_t &w_low, const int32_t &h_high, + const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tr, + T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp, + T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp, + T *grad_output_nram_br_temp, const int32_t &width, const int32_t &height, + const int32_t &deal_num_real, T *grad_h_weight, T *grad_w_weight, + T *top_grad_temp, T *top_grad, const T &data_attn_weight, const T &hw, + const T &hh, const T &lw, const T &lh, const T &w1, const T &w2, + const T &w3, const T &w4) { +#if __BANG_ARCH__ > 322 + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + if (h_low >= 0 && w_low >= 0) { + __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tl, (float)(-hw), + grad_h_weight, deal_num_real, deal_num_real); + __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tl, (float)(-hh), + grad_w_weight, deal_num_real, deal_num_real); + __bang_mul_scalar(grad_output_nram_tl_temp, top_grad_temp, w1, + deal_num_real); + // for calc grad_attn_weight + __bang_mul_scalar(grad_output_nram_tl, grad_output_nram_tl, w1, + deal_num_real); + } + if (h_low >= 0 && w_high <= width - 1) { + __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_tr, (float)(-lw), + grad_h_weight, deal_num_real, deal_num_real); + __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_tr, (float)(hh), + grad_w_weight, deal_num_real, deal_num_real); + __bang_mul_scalar(grad_output_nram_tr_temp, top_grad_temp, w2, + deal_num_real); + __bang_mul_scalar(grad_output_nram_tr, grad_output_nram_tr, w2, + deal_num_real); + __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_tr, + deal_num_real); + } + if (h_high <= height - 1 && w_low >= 0) { + __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_bl, (float)(hw), + grad_h_weight, deal_num_real, deal_num_real); + __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_bl, (float)(-lh), + grad_w_weight, deal_num_real, deal_num_real); + __bang_mul_scalar(grad_output_nram_bl_temp, top_grad_temp, w3, + deal_num_real); + // for calc grad_attn_weight + __bang_mul_scalar(grad_output_nram_bl, grad_output_nram_bl, w3, + deal_num_real); + __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_bl, + deal_num_real); + } + if (h_high <= height - 1 && w_high <= width - 1) { + __bang_fusion(FUSION_FMA, grad_h_weight, grad_output_nram_br, (float)(lw), + grad_h_weight, deal_num_real, deal_num_real); + __bang_fusion(FUSION_FMA, grad_w_weight, grad_output_nram_br, (float)(lh), + grad_w_weight, deal_num_real, deal_num_real); + __bang_mul_scalar(grad_output_nram_br_temp, top_grad_temp, w4, + deal_num_real); + // for calc grad_attn_weight + __bang_mul_scalar(grad_output_nram_br, grad_output_nram_br, w4, + deal_num_real); + __bang_add(grad_output_nram_tl, grad_output_nram_tl, grad_output_nram_br, + deal_num_real); + } + __bang_mul(grad_output_nram_tl, grad_output_nram_tl, top_grad, deal_num_real); + recursiveSumPool(grad_output_nram_tl, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); + __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real); + __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real); + + recursiveSumPool(grad_w_weight, 1, deal_num_real, 
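Judging from the __bang_mul_scalar plus __bang_sub/__bang_add pairs it replaces in the default kernel, __bang_fusion(FUSION_FMA, dst, src, s, add, n, n) computes an elementwise dst = src * s + add in a single pass (semantics inferred from context here, so treat this as an assumption). Reference loop:

// Assumed semantics of the fused op used above: one pass instead of a
// separate scalar multiply followed by an add/sub.
static void fusion_fma_ref(float *dst, const float *src, float s,
                           const float *add, int n) {
  for (int i = 0; i < n; ++i) dst[i] = src[i] * s + add[i];
}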
ALIGN_NUM_FOR_REDUCE); + __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real); + __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real); + recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); +#endif +} + +template +void __mlu_func__ storeData( + const int32_t &h_low, const int32_t &w_low, const int32_t &h_high, + const int32_t &w_high, T *grad_output_nram_tl, T *grad_output_nram_tl_temp, + T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp, + T *grad_output_nram_br_temp, const int32_t &width, const int32_t &height, + const int32_t &deal_num_real, const int32_t &h_low_ptr_offset, + const int32_t &w_low_ptr_offset, const int32_t &w_high_ptr_offset, + const int32_t &h_high_ptr_offset, const int32_t &base_ptr, T *grad_value, + T *grad_w_weight, T *grad_h_weight, T *grad_sampling_loc, + T *grad_attn_weight) { +#if __BANG_ARCH__ > 322 + if (h_low >= 0 && w_low >= 0) + + { + int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + __bang_atomic_add((T *)grad_output_nram_tl_temp, + (T *)(grad_value + offset1), + (T *)grad_output_nram_tl_temp, deal_num_real); + } + if (h_low >= 0 && w_high <= width - 1) + + { + int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + __bang_atomic_add((T *)grad_output_nram_tr_temp, + (T *)(grad_value + offset2), + (T *)grad_output_nram_tr_temp, deal_num_real); + } + if (h_high <= height - 1 && w_low >= 0) + + { + int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + __bang_atomic_add((T *)grad_output_nram_bl_temp, + (T *)(grad_value + offset3), + (T *)grad_output_nram_bl_temp, deal_num_real); + } + if (h_high <= height - 1 && w_high <= width - 1) + + { + int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + __bang_atomic_add((T *)grad_output_nram_br_temp, + (T *)(grad_value + offset4), + (T *)grad_output_nram_br_temp, deal_num_real); + } + __bang_atomic_add((T *)grad_output_nram_tl, (T *)grad_attn_weight, + (T *)grad_output_nram_tl, 1); + __bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc), + (T *)grad_w_weight, 1); + __bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1), + (T *)grad_h_weight, 1); +#endif +} + +template +void __mlu_func__ msDeformAttnCol2imBilinearSmallChannels( + T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1, + const T &w2, const T &w3, const T &w4, const int32_t &h_low, + const int32_t &w_low, const int32_t &h_high, const int32_t &w_high, + const int32_t &base_ptr, const int32_t &h_low_ptr_offset, + const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset, + const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh, + const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight, + T *grad_w_weight, T *grad_value, T *grad_output_nram_tl, + T *grad_output_nram_tr, T *grad_output_nram_bl, T *grad_output_nram_br, + T *grad_output_nram_tl_temp, T *grad_output_nram_tr_temp, + T *grad_output_nram_bl_temp, T *grad_output_nram_br_temp, + T *grad_sampling_loc, T *grad_attn_weight, const int32_t &deal_num_real, + const T *data_value_ptr) + +{ + loadData(h_low, w_low, h_high, w_high, grad_output_nram_tl, + grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br, + data_value_ptr, width, height, deal_num_real, h_low_ptr_offset, + w_low_ptr_offset, w_high_ptr_offset, h_high_ptr_offset, base_ptr); + computeData(h_low, w_low, h_high, w_high, grad_output_nram_tl, + grad_output_nram_tr, grad_output_nram_bl, grad_output_nram_br, + grad_output_nram_tl_temp, 
grad_output_nram_tr_temp, + grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height, + deal_num_real, grad_h_weight, grad_w_weight, top_grad_temp, + top_grad, data_attn_weight, hw, hh, lw, lh, w1, w2, w3, w4); + storeData(h_low, w_low, h_high, w_high, grad_output_nram_tl, + grad_output_nram_tl_temp, grad_output_nram_tr_temp, + grad_output_nram_bl_temp, grad_output_nram_br_temp, width, height, + deal_num_real, h_low_ptr_offset, w_low_ptr_offset, + w_high_ptr_offset, h_high_ptr_offset, base_ptr, grad_value, + grad_w_weight, grad_h_weight, grad_sampling_loc, grad_attn_weight); +} + +template +void __mlu_func__ msDeformAttnCol2imImpl( + T *top_grad_temp, T *top_grad, T *grad_h_weight, T *grad_w_weight, + T *grad_value, T *grad_output_nram_tl, T *grad_output_nram_tr, + T *grad_output_nram_bl, T *grad_output_nram_br, T *grad_output_nram_tl_temp, + T *grad_output_nram_tr_temp, T *grad_output_nram_bl_temp, + T *grad_output_nram_br_temp, T *grad_sampling_loc, T *grad_attn_weight, + T *nram_sampling_loc, T *nram_attn_weight, const int32_t &load_num, + const int32_t &tail, const int32_t &i_repeat, const int32_t &num_points, + const int32_t &start_per_core, const int32_t &num_levels, + const int32_t &num_heads, const int32_t &num_query, + const int32_t &spatial_size, const int32_t &qid_stride, + int32_t *level_start_index_nram, const int32_t &channels, + const T *data_value, const T *grad_output, int32_t *spatial_shapes_nram) { +#if __BANG_ARCH__ > 322 + int32_t weight_pos = 0; + int32_t sampling_loc_pos = 0; + for (int32_t p = 0; p < tail; ++p) { + int32_t grid_offset = start_per_core + i_repeat * load_num + p; + const int32_t l_col = grid_offset % num_levels; + const int32_t m_col = grid_offset / num_levels % num_heads; + const int32_t q_col = grid_offset / num_levels / num_heads % num_query; + const int32_t b_col = grid_offset / num_query / num_heads / num_levels; + const int32_t value_offset = b_col * spatial_size * qid_stride; + const int32_t level_start_id = level_start_index_nram[l_col]; + const int32_t grad_attn_weight_out = grid_offset * num_points; + const int32_t spatial_h_ptr = l_col << 1; + const int32_t grad_output_offset = + b_col * num_query * qid_stride + q_col * qid_stride + m_col * channels; + __memcpy(top_grad, grad_output + grad_output_offset, channels * LEN_FLOAT, + GDRAM2NRAM); + const int32_t spatial_h = spatial_shapes_nram[spatial_h_ptr]; + const int32_t spatial_w = spatial_shapes_nram[spatial_h_ptr + 1]; + const int32_t h_stride = spatial_w * qid_stride; + const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride; + const float *data_value_ptr = data_value + value_ptr_offset; + float *grad_value_ptr = grad_value + value_ptr_offset; + const int32_t grad_sampling_loc_out = grid_offset * num_points << 1; + + for (int32_t p_col = 0; p_col < num_points; ++p_col) { + const float loc_w = nram_sampling_loc[sampling_loc_pos]; + const float loc_h = nram_sampling_loc[sampling_loc_pos + 1]; + const float weight = nram_attn_weight[weight_pos]; + const float h_im = loc_h * spatial_h - 0.5; + const float w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { + const int32_t h_low = floorf(h_im); + const int32_t w_low = floorf(w_im); + const int32_t h_high = h_low + 1; + const int32_t w_high = w_low + 1; + const float lh = h_im - h_low; + const float lw = w_im - w_low; + const float hh = 1.0 - lh; + const float hw = 1.0 - lw; + const int32_t h_low_ptr_offset = h_low * h_stride; + const int32_t h_high_ptr_offset = 
h_low_ptr_offset + h_stride; + const int32_t w_low_ptr_offset = w_low * qid_stride; + const int32_t w_high_ptr_offset = w_low_ptr_offset + qid_stride; + const float w1 = hh * hw; + const float w2 = hh * lw; + const float w3 = lh * hw; + const float w4 = lh * lw; + const int32_t base_ptr = m_col * channels; + __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_output_nram_tl, PAD_UP(channels, ALIGN_NUM)); + msDeformAttnCol2imBilinearSmallChannels( + top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low, + h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset, + h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad, + weight, grad_h_weight, grad_w_weight, grad_value_ptr, + grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl, + grad_output_nram_br, grad_output_nram_tl_temp, + grad_output_nram_tr_temp, grad_output_nram_bl_temp, + grad_output_nram_br_temp, + grad_sampling_loc + grad_sampling_loc_out + (p_col << 1), + grad_attn_weight + grad_attn_weight_out + p_col, channels, + data_value_ptr); + } + weight_pos += 1; + sampling_loc_pos += 2; + } + } +#endif +} + +__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel( + const float *data_value, const int32_t *spatial_shapes, + const int32_t *data_level_start_index, const float *data_sampling_loc, + const float *data_attn_weight, const float *grad_output, + const int32_t batch, const int32_t spatial_size, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_query, + const int32_t num_points, float *grad_value, float *grad_sampling_loc, + float *grad_attn_weight) { +#if __BANG_ARCH__ > 322 + const int32_t split_num = 12; + const int32_t C_align = PAD_UP(channels, ALIGN_NUM); + float *grad_output_nram_tl = (float *)nram_buffer; + float *grad_output_nram_tr = (float *)nram_buffer + C_align; + float *grad_output_nram_bl = (float *)nram_buffer + 2 * C_align; + float *grad_output_nram_br = (float *)nram_buffer + 3 * C_align; + + float *grad_output_nram_tl_temp = (float *)nram_buffer + 4 * C_align; + float *grad_output_nram_tr_temp = (float *)nram_buffer + 5 * C_align; + float *grad_output_nram_bl_temp = (float *)nram_buffer + 6 * C_align; + float *grad_output_nram_br_temp = (float *)nram_buffer + 7 * C_align; + float *grad_h_weight = (float *)nram_buffer + 8 * C_align; + float *grad_w_weight = (float *)nram_buffer + 9 * C_align; + float *top_grad_temp = (float *)nram_buffer + 10 * C_align; + float *top_grad = (float *)nram_buffer + 11 * C_align; + + int32_t *spatial_shapes_nram = + (int32_t *)((float *)nram_buffer + split_num * C_align); + int32_t *level_start_index_nram = + (int32_t *)(spatial_shapes_nram + PAD_UP(num_levels * 2, ALIGN_NUM)); + float *nram_remain = (float *)((int32_t *)level_start_index_nram + + PAD_UP(num_levels, ALIGN_NUM)); + + // calc load num + const int32_t weight_num2nram = + (MAX_NRAM_SIZE / LEN_FLOAT - split_num * C_align - + 3 * PAD_UP(num_levels, ALIGN_NUM)) / + 3 / num_points; + int32_t load_num = weight_num2nram; + const int32_t total_num = batch * num_query * num_heads * num_levels; + int32_t num_per_core = total_num / taskDim; + int32_t num_rem = total_num % taskDim; + num_per_core = num_per_core + int32_t(taskId < num_rem); + if (num_per_core == 0) { + return; + } + const int32_t start_per_core = num_rem > taskId + ? 
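The offset arithmetic above follows the [spatial, num_heads, channels] layout of data_value: one row step is h_stride = spatial_w * num_heads * channels elements, one column step is qid_stride = num_heads * channels, and base_ptr = m_col * channels selects the head. A sketch of the corner addressing (helper name illustrative):

// Element offset of a bilinear corner (h, w) for a given head.
static long corner_offset(int h, int w, int head, int spatial_w,
                          int num_heads, int channels) {
  const long qid_stride = (long)num_heads * channels;
  const long h_stride = (long)spatial_w * qid_stride;
  return h * h_stride + w * qid_stride + (long)head * channels;
}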
(taskId * num_per_core) + : (num_rem + taskId * num_per_core); + const int32_t qid_stride = num_heads * channels; + + // load spatial_shapes and data_level_start_index to nram + __memcpy_async(spatial_shapes_nram, spatial_shapes, + num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); + __memcpy_async(level_start_index_nram, data_level_start_index, + num_levels * sizeof(int32_t), GDRAM2NRAM); + + const int32_t start_l_col = start_per_core % num_levels; + const int32_t start_m_col = start_per_core / num_levels % num_heads; + const int32_t start_q_col = + start_per_core / num_levels / num_heads % num_query; + const int32_t start_b_col = + start_per_core / num_query / num_heads / num_levels; + + const int32_t repeat = num_per_core / load_num; + const int32_t tail = num_per_core % load_num; + float *nram_sampling_loc = nram_remain; + float *nram_attn_weight = nram_sampling_loc + 2 * load_num * num_points; + + const int32_t attn_weight_offset = + start_b_col * num_query * num_heads * num_levels * num_points + + start_q_col * num_heads * num_levels * num_points + + start_m_col * num_levels * num_points + start_l_col * num_points; + const int32_t sampling_loc_offset = + start_b_col * num_query * num_heads * num_levels * num_points * 2 + + start_q_col * num_heads * num_levels * num_points * 2 + + start_m_col * num_levels * num_points * 2 + start_l_col * num_points * 2; + if (repeat > 0) { + for (int32_t i_repeat = 0; i_repeat < repeat; ++i_repeat) + + { // load weight and sampling_loc to nram + __memcpy_async(nram_sampling_loc, + data_sampling_loc + sampling_loc_offset + + i_repeat * load_num * 2 * num_points, + 2 * load_num * num_points * LEN_FLOAT, GDRAM2NRAM); + __memcpy(nram_attn_weight, + data_attn_weight + attn_weight_offset + + i_repeat * load_num * num_points, + load_num * num_points * LEN_FLOAT, GDRAM2NRAM); + msDeformAttnCol2imImpl( + top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value, + grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl, + grad_output_nram_br, grad_output_nram_tl_temp, + grad_output_nram_tr_temp, grad_output_nram_bl_temp, + grad_output_nram_br_temp, grad_sampling_loc, grad_attn_weight, + nram_sampling_loc, nram_attn_weight, load_num, load_num, i_repeat, + num_points, start_per_core, num_levels, num_heads, num_query, + spatial_size, qid_stride, level_start_index_nram, channels, + data_value, grad_output, spatial_shapes_nram); + } + } + if (tail > 0) + + { // load weight and sampling_loc to nram + __memcpy_async(nram_sampling_loc, + data_sampling_loc + sampling_loc_offset + + repeat * load_num * 2 * num_points, + tail * num_points * 2 * LEN_FLOAT, GDRAM2NRAM); + __memcpy( + nram_attn_weight, + data_attn_weight + attn_weight_offset + repeat * load_num * num_points, + tail * num_points * LEN_FLOAT, GDRAM2NRAM); + msDeformAttnCol2imImpl( + top_grad_temp, top_grad, grad_h_weight, grad_w_weight, grad_value, + grad_output_nram_tl, grad_output_nram_tr, grad_output_nram_bl, + grad_output_nram_br, grad_output_nram_tl_temp, grad_output_nram_tr_temp, + grad_output_nram_bl_temp, grad_output_nram_br_temp, grad_sampling_loc, + grad_attn_weight, nram_sampling_loc, nram_attn_weight, load_num, tail, + repeat, num_points, start_per_core, num_levels, num_heads, num_query, + spatial_size, qid_stride, level_start_index_nram, channels, data_value, + grad_output, spatial_shapes_nram); + } +#endif +} + +__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel( const float *data_value, const int32_t *spatial_shapes, const int32_t *data_level_start_index, const float 
*data_sampling_loc, const float *data_attn_weight, const float *grad_output, @@ -835,8 +1739,32 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( const int32_t channels, const int32_t num_levels, const int32_t num_query, const int32_t num_points, float *grad_value, float *grad_sampling_loc, float *grad_attn_weight); +__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel( + const float *data_value, const int32_t *spatial_shapes, + const int32_t *data_level_start_index, const float *data_sampling_loc, + const float *data_attn_weight, const float *grad_output, + const int32_t batch, const int32_t spatial_size, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_query, + const int32_t num_points, float *grad_value, float *grad_sampling_loc, + float *grad_attn_weight); + +void KernelMsDeformAttnBackwardDefaultKernel( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const float *data_value, + const int32_t *spatial_shapes, const int32_t *data_level_start_index, + const float *data_sampling_loc, const float *data_attn_weight, + const float *grad_output, const int32_t batch, const int32_t spatial_size, + const int32_t num_heads, const int32_t channels, const int32_t num_levels, + const int32_t num_query, const int32_t num_points, float *grad_value, + float *grad_sampling_loc, float *grad_attn_weight) { + MLUUnion1KernelMsDeformAttnBackwarDefaultKernel<<>>( + data_value, spatial_shapes, data_level_start_index, data_sampling_loc, + data_attn_weight, grad_output, batch, spatial_size, num_heads, channels, + num_levels, num_query, num_points, grad_value, grad_sampling_loc, + grad_attn_weight); +} -void KernelMsDeformAttnBackward( +void KernelMsDeformAttnBackwardSmallChannelsKernel( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t d_type, const float *data_value, const int32_t *spatial_shapes, const int32_t *data_level_start_index, @@ -845,7 +1773,8 @@ void KernelMsDeformAttnBackward( const int32_t num_heads, const int32_t channels, const int32_t num_levels, const int32_t num_query, const int32_t num_points, float *grad_value, float *grad_sampling_loc, float *grad_attn_weight) { - MLUUnion1KernelMsDeformAttnBackward<<>>( + MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel<<>>( data_value, spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, grad_output, batch, spatial_size, num_heads, channels, num_levels, num_query, num_points, grad_value, grad_sampling_loc, diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp index e93fd984aa..845465ae4b 100644 --- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp @@ -14,7 +14,15 @@ #define MIN(a, b) (((a) < (b)) ? (a) : (b)) -void KernelMsDeformAttnForward( +typedef enum { + MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. 
diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
index e93fd984aa..845465ae4b 100644
--- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
@@ -14,7 +14,15 @@
 
 #define MIN(a, b) (((a) < (b)) ? (a) : (b))
 
-void KernelMsDeformAttnForward(
+typedef enum {
+  MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */
+  MS_DEFORM_ATTN_FORWARD_DEFAULT =
+      1, /*!< MLUKernelMsDeformAttnForwardDefault */
+  MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL =
+      2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */
+} MsDeformAttnForwardPolicy;
+
+void KernelMsDeformAttnForwardDefault(
     cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
     const cnrtDataType_t d_type, const char* data_value_gdram,
     const char* data_spatial_shapes_gdram,
@@ -23,7 +31,37 @@ void KernelMsDeformAttnForward(
     const char* data_level_start_index_gdram,
     const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
     const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
     const int32_t channels, const int32_t num_levels, const int32_t num_queries,
     const int32_t num_points, char* data_col_gdram);
-void KernelMsDeformAttnBackward(
+void KernelMsDeformAttnForwardSmallChannel(
+    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+    const cnrtDataType_t d_type, const char* data_value_gdram,
+    const char* data_spatial_shapes_gdram,
+    const char* data_level_start_index_gdram,
+    const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
+    const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
+    const int32_t channels, const int32_t num_levels, const int32_t num_queries,
+    const int32_t num_points, char* data_col_gdram);
+
+typedef enum {
+  MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0,
+  MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1,
+} MsDeformAttnBackwardKernelPolicy;
+
+MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc(
+    const int32_t channels, const int32_t num_levels,
+    const int32_t num_points) {
+  const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
+  const uint64_t max_num = nram_size / sizeof(float);
+  const uint64_t deal_num =
+      12 * PAD_UP(channels, 32) + 3 * PAD_UP(num_levels, 32) + 3 * num_points;
+
+  if (max_num >= deal_num) {
+    return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL;
+  }
+
+  return MS_DEFORM_ATTN_BACKWARD_DEFAULT;
+}
+
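msDeformAttnBackwardPolicyFunc selects the small-channel backward kernel whenever the estimated per-core working set, counted in floats (twelve channel-aligned buffers plus level and point arrays), fits in NRAM. A standalone sketch of that arithmetic follows; the 512 KB NRAM size and the shapes are assumed examples, whereas the real code queries the size via torch_mlu::getDeviceAttr.

// Standalone sketch of the backward kernel policy arithmetic; the NRAM size
// and tensor shapes are illustrative assumptions.
#include <cstdint>
#include <cstdio>

static uint64_t pad_up(uint64_t x, uint64_t align) {
  return (x + align - 1) / align * align;
}

int main() {
  const uint64_t nram_size = 512 * 1024;               // assumed bytes/core
  const uint64_t max_num = nram_size / sizeof(float);  // 131072 floats
  const uint64_t channels = 64, num_levels = 4, num_points = 4;
  const uint64_t deal_num =
      12 * pad_up(channels, 32) + 3 * pad_up(num_levels, 32) + 3 * num_points;
  // 12 * 64 + 3 * 32 + 3 * 4 = 876 floats, far below 131072, so this shape
  // would take the small-channel path.
  std::printf("%s\n", max_num >= deal_num ? "SMALL_CHANNEL" : "DEFAULT");
  return 0;
}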
+void KernelMsDeformAttnBackwardDefaultKernel(
     cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
     const cnrtDataType_t d_type, const float* data_value,
     const int32_t* spatial_shapes, const int32_t* data_level_start_index,
@@ -32,10 +70,23 @@ void KernelMsDeformAttnBackward(
     const float* data_sampling_loc, const float* data_attn_weight,
     const float* grad_output, const int32_t batch, const int32_t spatial_size,
     const int32_t num_heads, const int32_t channels, const int32_t num_levels,
     const int32_t num_queries, const int32_t num_points, float* grad_value,
     float* grad_sampling_loc, float* grad_attn_weight);
+
+void KernelMsDeformAttnBackwardSmallChannelsKernel(
+    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+    const cnrtDataType_t d_type, const float* data_value,
+    const int32_t* spatial_shapes, const int32_t* data_level_start_index,
+    const float* data_sampling_loc, const float* data_attn_weight,
+    const float* grad_output, const int32_t batch, const int32_t spatial_size,
+    const int32_t num_heads, const int32_t channels, const int32_t num_levels,
+    const int32_t num_query, const int32_t num_points, float* grad_value,
+    float* grad_sampling_loc, float* grad_attn_weight);
+
 // policy function
-static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
-                              const int batch_size, const int num_queries,
-                              const int num_heads) {
+MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc(
+    cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size,
+    const int32_t num_keys, const int32_t num_heads, const int32_t channels,
+    const int32_t num_levels, const int32_t num_queries,
+    const int32_t num_points) {
   k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
   k_dim->y =
       MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
@@ -46,6 +97,15 @@ static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
 #else
   *k_type = CNRT_FUNC_TYPE_UNION1;
 #endif
+
+  int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
+  if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) {
+    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
+  } else if (channels > nram_size / 12 / sizeof(float)) {
+    return MS_DEFORM_ATTN_FORWARD_DEFAULT;
+  } else {
+    return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL;
+  }
 }
 
 // policy function for backward
@@ -196,7 +256,9 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
   // calculate task dimension
   cnrtDim3_t k_dim;
   cnrtFunctionType_t k_type;
-  policyFuncForward(&k_dim, &k_type, batch_size, num_queries, num_heads);
+  MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc(
+      &k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels,
+      num_queries, num_points);
 
   // get compute queue
   auto queue = torch_mlu::getCurQueue();
@@ -222,15 +284,33 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value,
   cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
 
   // launch kernel
-  CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForward<<<" << k_dim.x
-              << ", " << k_dim.y << ", " << k_dim.z << ">>>";
-
-  KernelMsDeformAttnForward(
-      k_dim, k_type, queue, data_type, (char*)value_ptr,
-      (char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
-      (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
-      num_heads, channels, num_levels, num_queries, num_points,
-      (char*)output_ptr);
+  switch (policy) {
+    default: {
+      VLOG(5) << "MsDeformAttnForward Policy not supported";
+    } break;
+    case MS_DEFORM_ATTN_FORWARD_DEFAULT: {
+      CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<"
+                  << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
+      KernelMsDeformAttnForwardDefault(
+          k_dim, k_type, queue, data_type, (char*)value_ptr,
+          (char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
+          (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
+          num_heads, channels, num_levels, num_queries, num_points,
+          (char*)output_ptr);
+      break;
+    }
+    case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: {
+      CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<"
+                  << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>";
+      KernelMsDeformAttnForwardSmallChannel(
+          k_dim, k_type, queue, data_type, (char*)value_ptr,
+          (char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
+          (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
+          num_heads, channels, num_levels, num_queries, num_points,
+          (char*)output_ptr);
+      break;
+    }
+  }
 
   output = output.view({batch_size, num_queries, num_heads * channels});
   return output;
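msDeformAttnForwardPolicyFunc both fills in the launch dimensions and picks the kernel that the switch above dispatches on. A standalone sketch of the k_dim arithmetic, with assumed device attributes (4 cores per cluster, 8 clusters); the real values come from torch_mlu::getDeviceAttr:

// Standalone sketch of the forward task-dimension computation; device
// attributes and tensor shapes are illustrative assumptions.
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t cores_per_cluster = 4, cluster_count = 8;  // assumed
  const int32_t batch_size = 2, num_queries = 900, num_heads = 8;
  const int32_t jobs = batch_size * num_queries * num_heads;  // 14400
  const int32_t k_dim_x = cores_per_cluster;
  // One task group per core-sized chunk of jobs, capped at the cluster count.
  int32_t k_dim_y = (jobs + k_dim_x - 1) / k_dim_x;      // ceil(14400/4) = 3600
  if (k_dim_y > cluster_count) k_dim_y = cluster_count;  // clamped to 8
  std::printf("k_dim = {%d, %d, 1}\n", (int)k_dim_x, (int)k_dim_y);
  return 0;
}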
"NotImplemented."; + } break; + case MS_DEFORM_ATTN_BACKWARD_DEFAULT: { + KernelMsDeformAttnBackwardDefaultKernel( + k_dim, k_type, queue, data_type, (float*)value_ptr, + (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr, + (float*)sampling_loc_ptr, (float*)attn_weight_ptr, + (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, + num_levels, num_queries, num_points, (float*)grad_value_ptr, + (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); + } break; + case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: { + KernelMsDeformAttnBackwardSmallChannelsKernel( + k_dim, k_type, queue, data_type, (float*)value_ptr, + (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr, + (float*)sampling_loc_ptr, (float*)attn_weight_ptr, + (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, + num_levels, num_queries, num_points, (float*)grad_value_ptr, + (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); + } break; + } } Tensor ms_deform_attn_impl_forward(const Tensor& value,