From d7bc6803c6332d72acd68c39c313efe48bd64f98 Mon Sep 17 00:00:00 2001 From: Zhang <33895156+ZhangLearning@users.noreply.github.com> Date: Mon, 3 Apr 2023 23:35:07 +0800 Subject: [PATCH] [Enhancement] Replace the implementation of roiaware_pool3d with mlu-ops. (#2699) * [Feature] Replace the implementation of roiaware_pool3d with mlu-ops. * [Feature] Replace the implementation of roiaware_pool3d with mlu-ops. --- .../common/mlu/roiaware_pool3d_mlu_kernel.mlu | 747 ------------------ .../csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp | 409 ++-------- 2 files changed, 87 insertions(+), 1069 deletions(-) delete mode 100644 mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu diff --git a/mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu deleted file mode 100644 index 4c1edf0bf5..0000000000 --- a/mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu +++ /dev/null @@ -1,747 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- *************************************************************************/ - -#include "common_mlu_helper.hpp" - -#define ROI_OFFSET 7 -#define FLOAT_NRAM_BUFFER_NUM 14 -#define HALF_NRAM_BUFFER_NUM 25 -#define ALIGN_NUM 64 - -__nram__ char data_nram[MAX_NRAM_SIZE]; - -template -__mlu_global__ void MLUUnion1KernelPtsIdxOfVoxels( - const int pool_method, const int boxes_num, const int pts_num, - const int max_pts_each_voxel, const int out_x, const int out_y, - const int out_z, const T *rois, const T *pts, int *pts_idx_of_voxels) { - // params (T)rois: (boxes_num, 7) - // params (T)pts: (3, pts_num) - // params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z, - // max_pts_each_voxel) - - // make sure that memcore is not used - if (coreId == 0x80) { - return; - } - int nram_pts_num = 0; - if (sizeof(T) == sizeof(float)) { - nram_pts_num = PAD_DOWN( - (MAX_NRAM_SIZE / sizeof(float) / FLOAT_NRAM_BUFFER_NUM), ALIGN_NUM); - } else { - nram_pts_num = PAD_DOWN( - (MAX_NRAM_SIZE / sizeof(half) / HALF_NRAM_BUFFER_NUM), ALIGN_NUM); - } - - char *X = NULL; - char *Y = NULL; - char *Z = NULL; - char *local_X = NULL; - char *local_Y = NULL; - char *local_Z = NULL; - char *nram_pts_in_flag = NULL; - float *temp_buffer1 = NULL; - float *temp_buffer2 = NULL; - float *temp_buffer3 = NULL; - float *temp_buffer4 = NULL; - float *temp_buffer5 = NULL; - float *nram_voxel_offset = NULL; - int *nram_pts_idx_seq = NULL; - float *fp_local_X = NULL; - float *fp_local_Y = NULL; - float *fp_local_Z = NULL; - float *fp_nram_pts_in_flag = NULL; - if (sizeof(T) == sizeof(float)) { - X = (char *)((float *)data_nram); - Y = (char *)((float *)data_nram + nram_pts_num); - Z = (char *)((float *)data_nram + nram_pts_num * 2); - local_X = (char *)((float *)data_nram + nram_pts_num * 3); - local_Y = (char *)((float *)data_nram + nram_pts_num * 4); - local_Z = (char *)((float *)data_nram + nram_pts_num * 5); - nram_pts_in_flag = (char *)((float *)data_nram + nram_pts_num * 6); - temp_buffer1 = 
(float *)data_nram + nram_pts_num * 7; - temp_buffer2 = (float *)data_nram + nram_pts_num * 8; - temp_buffer3 = (float *)data_nram + nram_pts_num * 9; - temp_buffer4 = (float *)data_nram + nram_pts_num * 10; - temp_buffer5 = (float *)data_nram + nram_pts_num * 11; - nram_voxel_offset = (float *)data_nram + nram_pts_num * 12; - nram_pts_idx_seq = (int *)((float *)data_nram + nram_pts_num * 13); - fp_local_X = (float *)local_X; - fp_local_Y = (float *)local_Y; - fp_local_Z = (float *)local_Z; - fp_nram_pts_in_flag = (float *)nram_pts_in_flag; - } else { - X = (char *)((half *)data_nram); - Y = (char *)((half *)data_nram + nram_pts_num); - Z = (char *)((half *)data_nram + nram_pts_num * 2); - local_X = (char *)((half *)data_nram + nram_pts_num * 4); - local_Y = (char *)((half *)data_nram + nram_pts_num * 6); - local_Z = (char *)((half *)data_nram + nram_pts_num * 8); - nram_pts_in_flag = (char *)((half *)data_nram + nram_pts_num * 10); - temp_buffer1 = (float *)((half *)data_nram + nram_pts_num * 11); - temp_buffer2 = (float *)((half *)data_nram + nram_pts_num * 13); - temp_buffer3 = (float *)((half *)data_nram + nram_pts_num * 15); - temp_buffer4 = (float *)((half *)data_nram + nram_pts_num * 17); - temp_buffer5 = (float *)((half *)data_nram + nram_pts_num * 19); - nram_voxel_offset = (float *)((half *)data_nram + nram_pts_num * 21); - nram_pts_idx_seq = (int *)((half *)data_nram + nram_pts_num * 23); - fp_local_X = (float *)((half *)local_X - nram_pts_num); - fp_local_Y = (float *)((half *)local_Y - nram_pts_num); - fp_local_Z = (float *)((half *)local_Z - nram_pts_num); - fp_nram_pts_in_flag = (float *)((half *)nram_pts_in_flag - nram_pts_num); - } - - for (int i = 0; i < nram_pts_num; i++) { - nram_pts_idx_seq[i] = i; - } - - int nram_pts_loop_times = pts_num / nram_pts_num; - int rem_nram_num = pts_num % nram_pts_num; - - for (int roi_index = taskId; roi_index < boxes_num; roi_index += taskDim) { - const T *cur_roi = rois + roi_index * ROI_OFFSET; - T cx = 
cur_roi[0]; - T cy = cur_roi[1]; - T cz = cur_roi[2]; - T dx = cur_roi[3]; - T dy = cur_roi[4]; - T dz = cur_roi[5]; - T rz = cur_roi[6]; - - T dx_2 = dx / 2.0; - T dy_2 = dy / 2.0; - T dz_2 = dz / 2.0; - - for (int loop_idx = 0; loop_idx <= nram_pts_loop_times; loop_idx++) { - int load_pts_num = - (loop_idx == nram_pts_loop_times) ? rem_nram_num : nram_pts_num; - if (load_pts_num == 0) { - break; - } - int pts_offset_cur_loop = nram_pts_num * loop_idx; - int compute_pts_num = (loop_idx == nram_pts_loop_times) - ? PAD_UP(rem_nram_num, ALIGN_NUM) - : nram_pts_num; - // load pts - __memcpy((void *)X, (T *)pts + pts_offset_cur_loop, - load_pts_num * sizeof(T), GDRAM2NRAM); - __memcpy((void *)Y, (T *)pts + pts_num + pts_offset_cur_loop, - load_pts_num * sizeof(T), GDRAM2NRAM); - __memcpy((void *)Z, (T *)pts + pts_num * 2 + pts_offset_cur_loop, - load_pts_num * sizeof(T), GDRAM2NRAM); - // fabs(local_z) - __bang_sub_scalar((T *)local_Z, (T *)Z, (T)cz, compute_pts_num); - __bang_sub_scalar((T *)temp_buffer1, (T *)Z, (T)(cz + dz_2), - compute_pts_num); - __bang_active_abs((T *)temp_buffer1, (T *)temp_buffer1, compute_pts_num); -#if __BANG_ARCH__ >= 322 - __bang_le_scalar((T *)nram_pts_in_flag, (T *)temp_buffer1, (T)(dz_2), - compute_pts_num); -#else - __bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dz_2)); - __bang_le((T *)nram_pts_in_flag, (T *)temp_buffer1, (T *)temp_buffer2, - compute_pts_num); -#endif - T cosa = std::cos(-rz); - T sina = std::sin(-rz); - __bang_sub_scalar((T *)temp_buffer3, (T *)X, (T)cx, compute_pts_num); - __bang_sub_scalar((T *)temp_buffer4, (T *)Y, (T)cy, compute_pts_num); - __bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)cosa, - compute_pts_num); - __bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)sina, - compute_pts_num); - // local_x - __bang_sub((T *)local_X, (T *)temp_buffer1, (T *)temp_buffer2, - compute_pts_num); - // fabs(local_x) - __bang_active_abs((T *)temp_buffer1, (T *)local_X, compute_pts_num); - // 
fabs(local_x) < dx/2 ? 1 : 0 -#if __BANG_ARCH__ >= 322 - __bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dx_2), - compute_pts_num); -#else - __bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dx_2)); - __bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2, - compute_pts_num); -#endif - __bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag, - (T *)temp_buffer1, - compute_pts_num); // flush res - - __bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)sina, - compute_pts_num); - __bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)cosa, - compute_pts_num); - // local_y - __bang_add((T *)local_Y, (T *)temp_buffer1, (T *)temp_buffer2, - compute_pts_num); - // fabs(local_y) - __bang_active_abs((T *)temp_buffer1, (T *)local_Y, compute_pts_num); - // fabs(local_y) < dy/2 ? 1 : 0 -#if __BANG_ARCH__ >= 322 - __bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dy_2), - compute_pts_num); -#else - __bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dy_2)); - __bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2, - compute_pts_num); -#endif - __bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag, - (T *)temp_buffer1, - compute_pts_num); // flush res - T x_res = dx / out_x; - T y_res = dy / out_y; - T z_res = dz / out_z; - __bang_add_scalar((T *)local_X, (T *)local_X, (T)(dx_2), compute_pts_num); - __bang_add_scalar((T *)local_Y, (T *)local_Y, (T)(dy_2), compute_pts_num); - // local_Z do not need to add dz/2.0 - -#if (__BANG_ARCH__ >= 322) && (__BANG_ARCH__ != 372) - __bang_div((T *)local_X, (T *)local_X, (T)x_res, compute_pts_num); - __bang_div((T *)local_Y, (T *)local_Y, (T)y_res, compute_pts_num); - __bang_div((T *)local_Z, (T *)local_Z, (T)z_res, compute_pts_num); -#else - __bang_mul_scalar((T *)local_X, (T *)local_X, (T)(1 / x_res), - compute_pts_num); - __bang_mul_scalar((T *)local_Y, (T *)local_Y, (T)(1 / y_res), - compute_pts_num); - __bang_mul_scalar((T *)local_Z, (T *)local_Z, (T)(1 / z_res), 
- compute_pts_num); -#endif - // float = float2int + int2float, half = half2int + int2float - if (sizeof(T) == sizeof(float)) { -#if __BANG_ARCH__ >= 322 - __bang_float2int32_tz((int *)temp_buffer1, (float *)local_X, - compute_pts_num, 0); - __bang_float2int32_tz((int *)temp_buffer2, (float *)local_Y, - compute_pts_num, 0); - __bang_float2int32_tz((int *)temp_buffer3, (float *)local_Z, - compute_pts_num, 0); - __bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1, - compute_pts_num, 0); - __bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2, - compute_pts_num, 0); - __bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3, - compute_pts_num, 0); -#else - convertFloat2Int((int *)temp_buffer1, (float *)temp_buffer2, - (float *)fp_local_X, (float *)temp_buffer3, - compute_pts_num); - convertFloat2Int((int *)temp_buffer2, (float *)temp_buffer3, - (float *)fp_local_Y, (float *)temp_buffer4, - compute_pts_num); - convertFloat2Int((int *)temp_buffer3, (float *)temp_buffer4, - (float *)fp_local_Z, (float *)temp_buffer5, - compute_pts_num); - convertInt2Float((float *)fp_local_X, (float *)temp_buffer4, - (int *)temp_buffer1, (float *)temp_buffer5, - compute_pts_num); - convertInt2Float((float *)fp_local_Y, (float *)temp_buffer4, - (int *)temp_buffer2, (float *)temp_buffer5, - compute_pts_num); - convertInt2Float((float *)fp_local_Z, (float *)temp_buffer4, - (int *)temp_buffer3, (float *)temp_buffer5, - compute_pts_num); -#endif - } else { - __bang_half2float((float *)temp_buffer4, (half *)nram_pts_in_flag, - compute_pts_num); - __bang_move((void *)fp_nram_pts_in_flag, (void *)temp_buffer4, - compute_pts_num * sizeof(float)); -#if __BANG_ARCH__ >= 322 - __bang_half2int32_tz((int *)temp_buffer1, (half *)local_X, - compute_pts_num, 0); - __bang_half2int32_tz((int *)temp_buffer2, (half *)local_Y, - compute_pts_num, 0); - __bang_half2int32_tz((int *)temp_buffer3, (half *)local_Z, - compute_pts_num, 0); - __bang_int322float_rn((float *)fp_local_X, (int 
*)temp_buffer1, - compute_pts_num, 0); - __bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2, - compute_pts_num, 0); - __bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3, - compute_pts_num, 0); -#else - __bang_half2int16_tz((int16_t *)temp_buffer1, (half *)local_X, - compute_pts_num, 0); - __bang_half2int16_tz((int16_t *)temp_buffer2, (half *)local_Y, - compute_pts_num, 0); - __bang_half2int16_tz((int16_t *)temp_buffer3, (half *)local_Z, - compute_pts_num, 0); - __bang_int162float((float *)fp_local_X, (int16_t *)temp_buffer1, - compute_pts_num, 0); - __bang_int162float((float *)fp_local_Y, (int16_t *)temp_buffer2, - compute_pts_num, 0); - __bang_int162float((float *)fp_local_Z, (int16_t *)temp_buffer3, - compute_pts_num, 0); -#endif - } - // process index >= 0 - __bang_write_value((float *)temp_buffer4, compute_pts_num, (float)0.0f); - __bang_maxequal((float *)fp_local_X, (float *)fp_local_X, - (float *)temp_buffer4, compute_pts_num); - __bang_maxequal((float *)fp_local_Y, (float *)fp_local_Y, - (float *)temp_buffer4, compute_pts_num); - __bang_maxequal((float *)fp_local_Z, (float *)fp_local_Z, - (float *)temp_buffer4, compute_pts_num); - // process index <= (out_x - 1) - __bang_write_value((float *)temp_buffer5, compute_pts_num, - (float)(out_x - 1)); - __bang_minequal((float *)fp_local_X, (float *)fp_local_X, - (float *)temp_buffer5, compute_pts_num); - __bang_write_value((float *)temp_buffer5, compute_pts_num, - (float)(out_y - 1)); - __bang_minequal((float *)fp_local_Y, (float *)fp_local_Y, - (float *)temp_buffer5, compute_pts_num); - __bang_write_value((float *)temp_buffer5, compute_pts_num, - (float)(out_z - 1)); - __bang_minequal((float *)fp_local_Z, (float *)fp_local_Z, - (float *)temp_buffer5, compute_pts_num); - __bang_mul_scalar((float *)temp_buffer1, (float *)fp_local_X, - (float)(out_y * out_z), compute_pts_num); - __bang_mul_scalar((float *)temp_buffer2, (float *)fp_local_Y, - (float)out_z, compute_pts_num); - 
__bang_mul_scalar((float *)temp_buffer3, (float *)fp_local_Z, (float)1.0, - compute_pts_num); - __bang_add((float *)nram_voxel_offset, (float *)temp_buffer1, - (float *)temp_buffer2, compute_pts_num); - __bang_add((float *)nram_voxel_offset, (float *)nram_voxel_offset, - (float *)temp_buffer3, compute_pts_num); - __bang_mul_scalar((float *)nram_voxel_offset, (float *)nram_voxel_offset, - (float)max_pts_each_voxel, compute_pts_num); - if (compute_pts_num != load_pts_num) { - __memset_nram((float *)fp_nram_pts_in_flag + load_pts_num, - compute_pts_num - load_pts_num, (float)0.0); - } - __bang_collect((float *)temp_buffer4, (float *)nram_pts_idx_seq, - (float *)fp_nram_pts_in_flag, compute_pts_num); - int pts_num_in_cur_roi = - (int)__bang_count((float *)fp_nram_pts_in_flag, compute_pts_num); - int *pts_idx_cur_voxels = - (int *)pts_idx_of_voxels + - roi_index * out_x * out_y * out_z * max_pts_each_voxel; - for (int idx = 0; idx < pts_num_in_cur_roi; idx++) { - int cur_pts_idx = *((int *)temp_buffer4 + idx); - int offset = (int)(*((float *)nram_voxel_offset + cur_pts_idx)); - int cnt = pts_idx_cur_voxels[offset]; - if (cnt < max_pts_each_voxel - 1) { - pts_idx_cur_voxels[offset + cnt + 1] = - cur_pts_idx + loop_idx * nram_pts_num; - pts_idx_cur_voxels[offset]++; - } - } - } - } -} - -template -__mlu_global__ void MLUUnion1KernelRoiawarePool3dForward( - const int pool_method, const int boxes_num, const int pts_num, - const int channels, const int max_pts_each_voxel, const int out_x, - const int out_y, const int out_z, const T *pts_feature, - const int *pts_idx_of_voxels, T *pooled_features, int *argmax) { - // params (T)pts_feature: (channels, pts_num) - // params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z, - // max_pts_each_voxel) params (int)argmax: (boxes_num, out_x, out_y, out_z, - // channels) params (T)pooled_features: (boxes_num, out_x, out_y, out_z, - // channels) - - // make sure that memcore is not used - if (coreId == 0x80) { - return; - } - int 
align_num = NFU_ALIGN_SIZE / sizeof(T); - int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num); - int nram_channels_limit = - PAD_DOWN((MAX_NRAM_SIZE - 128 - - align_max_pts_each_voxel * (sizeof(int) + sizeof(T))) / - ((align_max_pts_each_voxel + 1) * sizeof(T) + sizeof(int)), - align_num); - int *nram_pts_idx_cur_voxel = (int *)data_nram; - // nram_pts_idx_cur_voxel [align_max_pts_each_voxel] - T *nram_max_pts_feature_tmp = - (T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel); - // nram_max_pts_feature_tmp [align_max_pts_each_voxel] - T *nram_pts_feature_in_voxel = - ((T *)nram_max_pts_feature_tmp + align_max_pts_each_voxel); - // nram_pts_feature_in_voxel [nram_channels_limit, align_max_pts_each_voxel] - T *nram_pooled_features_cur_voxel = - ((T *)nram_pts_feature_in_voxel + - nram_channels_limit * align_max_pts_each_voxel); - // nram_pooled_features_cur_voxel [nram_channels_limit] - int *nram_argmax_cur_voxel = - (int *)((T *)nram_pooled_features_cur_voxel + nram_channels_limit); - // nram_argmax_cur_voxel [nram_channels_limit] - char *one_pooled_feature = - (char *)((int *)nram_argmax_cur_voxel + nram_channels_limit); - // one_pooled_feature [128] - int channels_loop_times = channels / nram_channels_limit; - int rem_channels = channels % nram_channels_limit; - for (int voxel_index = taskId; - voxel_index < boxes_num * out_x * out_y * out_z; - voxel_index += taskDim) { - int *pts_idx_cur_voxels = - (int *)pts_idx_of_voxels + voxel_index * max_pts_each_voxel; - __memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxels, - max_pts_each_voxel * sizeof(int), GDRAM2NRAM); - int pts_num_cur_voxel = nram_pts_idx_cur_voxel[0]; - if (pts_num_cur_voxel == 0) { - continue; - } - for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times; - channels_loop_idx++) { - int actual_channels_num = (channels_loop_idx == channels_loop_times) - ? 
rem_channels - : nram_channels_limit; - if (actual_channels_num == 0) { - break; - } - int channels_offset = nram_channels_limit * channels_loop_idx; - -#if ((__BANG_ARCH__ >= 200) && (__BANG_ARCH__ < 300)) - int compute_channels_num = (channels_loop_idx == channels_loop_times) - ? PAD_UP(rem_channels, align_num) - : nram_channels_limit; - if (pool_method == 0) { - __bang_write_value((void *)nram_pts_feature_in_voxel, - compute_channels_num * align_max_pts_each_voxel, - (T)-INFINITY); - } -#endif - - T *pts_feature_cur_loop = (T *)pts_feature + channels_offset * pts_num; - for (int idx = 0; idx < pts_num_cur_voxel; idx++) { - __memcpy((T *)nram_pts_feature_in_voxel + idx, - (T *)pts_feature_cur_loop + nram_pts_idx_cur_voxel[idx + 1], - sizeof(T), GDRAM2NRAM, align_max_pts_each_voxel * sizeof(T), - pts_num * sizeof(T), actual_channels_num - 1); - } - for (int channel_idx = 0; channel_idx < actual_channels_num; - channel_idx++) { - if (pool_method == 0) { -#if __BANG_ARCH__ >= 322 - __bang_argmax((T *)one_pooled_feature, - (T *)nram_pts_feature_in_voxel + - channel_idx * align_max_pts_each_voxel, - pts_num_cur_voxel); - T max_val = ((T *)one_pooled_feature)[0]; - int max_idx = (int)(*(uint32_t *)((T *)one_pooled_feature + 1)); - nram_pooled_features_cur_voxel[channel_idx] = - (max_val == -INFINITY) ? 0 : max_val; - nram_argmax_cur_voxel[channel_idx] = - (max_val == -INFINITY) ? 
-1 : nram_pts_idx_cur_voxel[max_idx + 1]; -#else - // __bang_max need align num on mlu200 series - if (sizeof(T) == sizeof(float)) { - __bang_max((float *)one_pooled_feature, - (float *)nram_pts_feature_in_voxel + - channel_idx * align_max_pts_each_voxel, - align_max_pts_each_voxel); - float max_val = ((float *)one_pooled_feature)[0]; - __bang_write_value((void *)nram_max_pts_feature_tmp, - align_max_pts_each_voxel, (float)max_val); - __bang_eq((float *)nram_max_pts_feature_tmp, - (float *)nram_pts_feature_in_voxel + - channel_idx * align_max_pts_each_voxel, - (float *)nram_max_pts_feature_tmp, - align_max_pts_each_voxel); - int max_idx = (int)__bang_findfirst1( - (float *)nram_max_pts_feature_tmp, align_max_pts_each_voxel); - nram_pooled_features_cur_voxel[channel_idx] = - (max_val == -INFINITY) ? 0 : max_val; - nram_argmax_cur_voxel[channel_idx] = - (max_val == -INFINITY) ? -1 - : nram_pts_idx_cur_voxel[max_idx + 1]; - } else { - int max_idx = -1; - float max_val = -INFINITY; - for (int k = 0; k < pts_num_cur_voxel; k++) { - float pts_feature_cur_channel = __half2float_rd( - *((half *)nram_pts_feature_in_voxel + - channel_idx * align_max_pts_each_voxel + k)); - if (pts_feature_cur_channel > max_val) { - max_val = pts_feature_cur_channel; - max_idx = k; - } - } - nram_pooled_features_cur_voxel[channel_idx] = - (max_idx == -1) ? 0 : max_val; - nram_argmax_cur_voxel[channel_idx] = - (max_idx == -1) ? 
-1 : nram_pts_idx_cur_voxel[max_idx + 1]; - } -#endif - } else if (pool_method == 1) { - float sum_val_cur_channel = 0; - for (int k = 0; k < pts_num_cur_voxel; k++) { - sum_val_cur_channel += static_cast( - ((T *)nram_pts_feature_in_voxel)[channel_idx * - align_max_pts_each_voxel + - k]); - } - nram_pooled_features_cur_voxel[channel_idx] = - (T)(sum_val_cur_channel / pts_num_cur_voxel); - } - } - // store - __memcpy((T *)pooled_features + voxel_index * channels + channels_offset, - (void *)nram_pooled_features_cur_voxel, - actual_channels_num * sizeof(T), NRAM2GDRAM); - if (pool_method == 0) { - __memcpy((int *)argmax + voxel_index * channels + channels_offset, - (void *)nram_argmax_cur_voxel, - actual_channels_num * sizeof(int), NRAM2GDRAM); - } - } - } -} - -void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, const cnrtDataType_t d_type, - const int pool_method, const int boxes_num, - const int pts_num, const int max_pts_each_voxel, - const int out_x, const int out_y, const int out_z, - const void *rois, const void *pts, - int *pts_idx_of_voxels) { - switch (d_type) { - case CNRT_FLOAT32: { - MLUUnion1KernelPtsIdxOfVoxels<<>>( - pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, - out_z, (float *)rois, (float *)pts, (int *)pts_idx_of_voxels); - }; break; - case CNRT_FLOAT16: { - MLUUnion1KernelPtsIdxOfVoxels<<>>( - pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, - out_z, (half *)rois, (half *)pts, (int *)pts_idx_of_voxels); - }; break; - default: { - break; - } - } -} - -void KernelRoiawarePool3dForward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const int pool_method, const int boxes_num, - const int pts_num, const int channels, const int max_pts_each_voxel, - const int out_x, const int out_y, const int out_z, const void *pts_feature, - const int *pts_idx_of_voxels, void *pooled_features, int *argmax) { - switch (d_type) { - case 
CNRT_FLOAT32: { - MLUUnion1KernelRoiawarePool3dForward<<>>( - pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x, - out_y, out_z, (float *)pts_feature, (int *)pts_idx_of_voxels, - (float *)pooled_features, (int *)argmax); - }; break; - case CNRT_FLOAT16: { - MLUUnion1KernelRoiawarePool3dForward<<>>( - pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x, - out_y, out_z, (half *)pts_feature, (int *)pts_idx_of_voxels, - (half *)pooled_features, (int *)argmax); - }; break; - default: { - break; - } - } -} - -template -__mlu_global__ void MLUUnion1KernelRoiawareMaxPool3dBackward( - const int boxes_num, const int out_x, const int out_y, const int out_z, - const int channels, const int *argmax, const T *grad_out, T *grad_in) { - // params (int)argmax: (boxes_num, out_x, out_y, out_z, channels) - // params (T)grad_out: (boxes_num, out_x, out_y, out_z, channels) - // params (T)grad_in: (pts_num, channels) - - // make sure that memcore is not used - if (coreId == 0x80) { - return; - } - int nram_channels_limit = - (MAX_NRAM_SIZE - sizeof(T) * 1) / (sizeof(T) + sizeof(int)); - int *nram_argmax_cur_loop = (int *)data_nram; - // nram_argmax_cur_loop [nram_channels_limit] - T *nram_grad_out_cur_loop = - (T *)((int *)nram_argmax_cur_loop + nram_channels_limit); - // nram_grad_out_cur_loop [nram_channels_limit] - T *nram_grad_in_cur_channel = - (T *)nram_grad_out_cur_loop + nram_channels_limit; - // nram_grad_in_cur_channel [1] - int channels_loop_times = channels / nram_channels_limit; - int rem_channels = channels % nram_channels_limit; - int voxels_num = boxes_num * out_x * out_y * out_z; - - for (int voxel_index = taskId; voxel_index < voxels_num; - voxel_index += taskDim) { - const int *argmax_cur_voxel = argmax + voxel_index * channels; - const T *grad_out_cur_voxel = grad_out + voxel_index * channels; - - for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times; - channels_loop_idx++) { - int actual_channels_num = 
(channels_loop_idx == channels_loop_times) - ? rem_channels - : nram_channels_limit; - if (actual_channels_num == 0) { - break; - } - const int *argmax_cur_loop = - argmax_cur_voxel + nram_channels_limit * channels_loop_idx; - const T *grad_out_cur_loop = - grad_out_cur_voxel + nram_channels_limit * channels_loop_idx; - __memcpy((void *)nram_argmax_cur_loop, (void *)argmax_cur_loop, - actual_channels_num * sizeof(int), GDRAM2NRAM); - __memcpy((void *)nram_grad_out_cur_loop, (void *)grad_out_cur_loop, - actual_channels_num * sizeof(T), GDRAM2NRAM); - - for (int channel_idx = 0; channel_idx < actual_channels_num; - channel_idx++) { - int *nram_argmax_cur_channel = nram_argmax_cur_loop + channel_idx; - T *nram_grad_out_cur_channel = nram_grad_out_cur_loop + channel_idx; - if (nram_argmax_cur_channel[0] == -1) { - continue; - } - T *grad_in_cur_channel = - grad_in + nram_argmax_cur_channel[0] * channels + - nram_channels_limit * channels_loop_idx + channel_idx; - __bang_atomic_add((T *)nram_grad_in_cur_channel, - (T *)grad_in_cur_channel, - (T *)(nram_grad_out_cur_channel), 1); - } - } - } -} - -template -__mlu_global__ void MLUUnion1KernelRoiawareAvgPool3dBackward( - const int boxes_num, const int out_x, const int out_y, const int out_z, - const int channels, const int max_pts_each_voxel, - const int *pts_idx_of_voxels, const T *grad_out, T *grad_in) { - // params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z, - // max_pts_each_voxel) params (T)grad_out: (boxes_num, out_x, out_y, out_z, - // channels) params (T)grad_in: (pts_num, channels) - - // make sure that memcore is not used - if (coreId == 0x80) { - return; - } - int align_num = NFU_ALIGN_SIZE / sizeof(T); - int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num); - int nram_channels_limit = PAD_DOWN( - (MAX_NRAM_SIZE - align_max_pts_each_voxel * sizeof(int)) / 2 / sizeof(T), - align_num); - int *nram_pts_idx_cur_voxel = (int *)data_nram; - // nram_pts_idx_cur_voxel 
[align_max_pts_each_voxel] - T *nram_grad_out_cur_loop = - (T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel); - // nram_grad_out_cur_loop [nram_channels_limit] - T *nram_grad_in_cur_loop = (T *)nram_grad_out_cur_loop + nram_channels_limit; - // nram_grad_in_cur_loop [nram_channels_limit] - int channels_loop_times = channels / nram_channels_limit; - int rem_channels = channels % nram_channels_limit; - int voxels_num = boxes_num * out_x * out_y * out_z; - - for (int voxel_index = taskId; voxel_index < voxels_num; - voxel_index += taskDim) { - const T *grad_out_cur_voxel = grad_out + voxel_index * channels; - const int *pts_idx_cur_voxel = - pts_idx_of_voxels + voxel_index * max_pts_each_voxel; - __memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxel, - max_pts_each_voxel * sizeof(int), GDRAM2NRAM); - int total_pts_of_voxel = nram_pts_idx_cur_voxel[0]; - if (total_pts_of_voxel <= 0) { - continue; - } - float cur_grad = 1.0 / ((float)total_pts_of_voxel); - - for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times; - channels_loop_idx++) { - int actual_channels_num = (channels_loop_idx == channels_loop_times) - ? 
rem_channels - : nram_channels_limit; - if (actual_channels_num == 0) { - break; - } - const T *grad_out_cur_loop = - grad_out_cur_voxel + nram_channels_limit * channels_loop_idx; - __memcpy((void *)nram_grad_in_cur_loop, (void *)grad_out_cur_loop, - actual_channels_num * sizeof(T), GDRAM2NRAM); - - int align_actual_channels_num = PAD_UP(actual_channels_num, align_num); - - if (sizeof(T) == sizeof(half)) { - __bang_half2float((float *)nram_grad_out_cur_loop, - (half *)nram_grad_in_cur_loop, - align_actual_channels_num); - __bang_mul_scalar((float *)nram_grad_out_cur_loop, - (float *)nram_grad_out_cur_loop, (float)cur_grad, - align_actual_channels_num); - convertFloat2half((half *)nram_grad_out_cur_loop, - (float *)nram_grad_out_cur_loop, - align_actual_channels_num); - } else { - __bang_mul_scalar((float *)nram_grad_out_cur_loop, - (float *)nram_grad_in_cur_loop, (float)cur_grad, - align_actual_channels_num); - } - for (int k = 1; k <= total_pts_of_voxel; k++) { - T *grad_in_cur_loop = grad_in + nram_pts_idx_cur_voxel[k] * channels + - nram_channels_limit * channels_loop_idx; - __bang_atomic_add((T *)nram_grad_in_cur_loop, (T *)grad_in_cur_loop, - (T *)nram_grad_out_cur_loop, actual_channels_num); - } - } - } -} - -void KernelRoiawarePool3dBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const int pool_method, const int boxes_num, - const int out_x, const int out_y, const int out_z, const int channels, - const int max_pts_each_voxel, const int *pts_idx_of_voxels, - const int *argmax, const void *grad_out, void *grad_in) { - if (pool_method == 0) { - switch (d_type) { - case CNRT_FLOAT32: { - MLUUnion1KernelRoiawareMaxPool3dBackward - <<>>(boxes_num, out_x, out_y, out_z, channels, - (int *)argmax, (float *)grad_out, - (float *)grad_in); - }; break; - case CNRT_FLOAT16: { - MLUUnion1KernelRoiawareMaxPool3dBackward - <<>>(boxes_num, out_x, out_y, out_z, channels, - (int *)argmax, (half *)grad_out, - (half 
*)grad_in); - }; break; - default: { - break; - } - } - } else { - switch (d_type) { - case CNRT_FLOAT32: { - MLUUnion1KernelRoiawareAvgPool3dBackward - <<>>( - boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel, - (int *)pts_idx_of_voxels, (float *)grad_out, (float *)grad_in); - }; break; - case CNRT_FLOAT16: { - MLUUnion1KernelRoiawareAvgPool3dBackward - <<>>( - boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel, - (int *)pts_idx_of_voxels, (half *)grad_out, (half *)grad_in); - }; break; - default: { - break; - } - } - } -} diff --git a/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp index 62cb2dc62e..a1c4da4ca3 100644 --- a/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp @@ -9,49 +9,7 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" - -void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, const cnrtDataType_t d_type, - const int pool_method, const int boxes_num, - const int pts_num, const int max_pts_each_voxel, - const int out_x, const int out_y, const int out_z, - const void *rois, const void *pts, - int *pts_idx_of_voxels); - -void KernelRoiawarePool3dForward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const int pool_method, const int boxes_num, - const int pts_num, const int channels, const int max_pts_each_voxel, - const int out_x, const int out_y, const int out_z, const void *pts_feature, - const int *pts_idx_of_voxels, void *pooled_features, int *argmax); - -// policy function -static void kernelPtsIdxOfVoxelsPolicyFunc(const int boxes_num, - cnrtDim3_t *k_dim, - cnrtFunctionType_t *k_type) { - unsigned int 
core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - *k_type = CNRT_FUNC_TYPE_UNION1; - k_dim->x = core_num; - unsigned int use_cluster = (boxes_num + core_num - 1) / core_num; - k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster; - k_dim->z = 1; -} - -static void kernelRoiawarePool3dForwardPolicyFunc( - const int boxes_num, const int out_x, const int out_y, const int out_z, - cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { - unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - *k_type = CNRT_FUNC_TYPE_UNION1; - k_dim->x = core_num; - const int voxels_num = boxes_num * out_x * out_y * out_z; - unsigned int use_cluster = (voxels_num + core_num - 1) / core_num; - k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster; - k_dim->z = 1; -} +#include "mlu_common_helper.h" void RoiawarePool3dForwardMLUKernelLauncher( const int pool_method, const int boxes_num, const int pts_num, @@ -59,168 +17,65 @@ void RoiawarePool3dForwardMLUKernelLauncher( const int out_y, const int out_z, const Tensor rois, const Tensor pts, const Tensor pts_feature, Tensor pts_idx_of_voxels, Tensor pooled_features, Tensor argmax) { - // check datatype - TORCH_CHECK(((pts.scalar_type() == rois.scalar_type()) && - (pts_feature.scalar_type() == rois.scalar_type()) && - (pooled_features.scalar_type() == rois.scalar_type())), - "data types of rois, rois, pts_feature and pooled_features " - "should be the same, ", - "but now rois type is ", rois.scalar_type(), ", pts type is ", - pts.scalar_type(), ", pts_feature type is ", - pts_feature.scalar_type(), ", pooled_features type is ", - pooled_features.scalar_type(), "."); - TORCH_CHECK( - (rois.scalar_type() == at::kFloat || rois.scalar_type() == at::kHalf), - "rois type should be Float or Half, got ", rois.scalar_type(), "."); - 
TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt), - "pts_idx_of_voxels type should be Int, got ", - pts_idx_of_voxels.scalar_type(), "."); - // check dim - TORCH_CHECK(rois.dim() == 2, "rois should be a 2D tensor, got ", rois.dim(), - "D."); - TORCH_CHECK(pts.dim() == 2, "pts should be a 2D tensor, got ", pts.dim(), - "D."); - TORCH_CHECK(pts_feature.dim() == 2, "pts_feature should be a 2D tensor, got ", - pts_feature.dim(), "D."); - TORCH_CHECK(pts_idx_of_voxels.dim() == 5, - "pts_idx_of_voxels should be a 5D tensor, got ", - pts_idx_of_voxels.dim(), "D."); - TORCH_CHECK(pooled_features.dim() == 5, - "pooled_features should be a 5D tensor, got ", - pooled_features.dim(), "D."); - // check shape - TORCH_CHECK(((rois.size(0) == boxes_num) && (rois.size(1) == 7)), - "the dimensions of rois should be (boxes_num, 7), ", "but got (", - rois.size(0), ", ", rois.size(1), ") ."); - TORCH_CHECK(((pts.size(0) == pts_num) && (pts.size(1) == 3)), - "the dimensions of pts should be (pts_num, 3), ", "but got (", - pts.size(0), ",", pts.size(1), ")."); - TORCH_CHECK( - ((pts_feature.size(0) == pts_num) && (pts_feature.size(1) == channels)), - "the dimensions of pts_feature should be (pts_num, channels), ", - "but got (", pts_feature.size(0), ",", pts_feature.size(1), ")."); - TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) && - (pts_idx_of_voxels.size(1) == out_x) && - (pts_idx_of_voxels.size(2) == out_y) && - (pts_idx_of_voxels.size(3) == out_z) && - (pts_idx_of_voxels.size(4) == max_pts_each_voxel)), - "the dimensions of pts_idx_of_voxels should be (boxes_num, " - "out_x, out_y, out_z, max_pts_each_voxel), ", - "but got (", pts_idx_of_voxels.size(0), ",", - pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",", - pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ")."); - TORCH_CHECK(((pooled_features.size(0) == boxes_num) && - (pooled_features.size(1) == out_x) && - (pooled_features.size(2) == out_y) && - (pooled_features.size(3) == out_z) && - 
(pooled_features.size(4) == channels)), - "the dimensions of pooled_features should be (boxes_num, out_x, " - "out_y, out_z, channels), ", - "but got (", pooled_features.size(0), ",", - pooled_features.size(1), ",", pooled_features.size(2), ",", - pooled_features.size(3), ",", pooled_features.size(4), ")."); - // check other params : pool_mothod - TORCH_CHECK(((pool_method == 0) || (pool_method == 1)), - "the num of pool_method should be 0(max) or 1(avg), ", "but got ", - pool_method, "."); - // check large tensor - const size_t max_input_size = 2147483648; - TORCH_CHECK(rois.numel() < max_input_size, - "rois element num should be less than 2^31, got ", rois.numel(), - "."); - TORCH_CHECK(pts.numel() < max_input_size, - "pts element num should be less than 2^31, got ", pts.numel(), - "."); - TORCH_CHECK(pts_feature.numel() < max_input_size, - "pts_feature element num should be less than 2^31, got ", - pts_feature.numel(), "."); - TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size, - "pts_idx_of_voxels element num should be less than 2^31, got ", - pts_idx_of_voxels.numel(), "."); - TORCH_CHECK(pooled_features.numel() < max_input_size, - "pooled_features element num should be less than 2^31, got ", - pooled_features.numel(), "."); - // check zero element - TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", - rois.numel()); - TORCH_CHECK(pts.numel() != 0, "pts.numel() should not be zero, got ", - pts.numel()); - TORCH_CHECK(pts_feature.numel() != 0, - "pts_feature.numel() should not be zero, got ", - pts_feature.numel()); - TORCH_CHECK(pts_idx_of_voxels.numel() != 0, - "pts_idx_of_voxels.numel() should not be zero, got ", - pts_idx_of_voxels.numel()); - TORCH_CHECK(pooled_features.numel() != 0, - "pooled_features.numel() should not be zero, got ", - pooled_features.numel()); - if (pool_method == 0) { - // check datatype - TORCH_CHECK((argmax.scalar_type() == at::kInt), - "argmax type should be Int, got ", argmax.scalar_type(), "."); - // 
check dim - TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ", - argmax.dim(), "D."); - // check shape - TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) && - (argmax.size(2) == out_y) && (argmax.size(3) == out_z) && - (argmax.size(4) == channels)), - "the dimensions of argmax should be (boxes_num, out_x, out_y, " - "out_z, channels), ", - "but got (", argmax.size(0), ",", argmax.size(1), ",", - argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ")."); - // check large tensor - TORCH_CHECK(argmax.numel() < max_input_size, - "argmax element num should be less than 2^31, got ", - argmax.numel(), "."); - // check zero element - TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ", - argmax.numel()); - // when pool_method is 0, which is max pool, init argmax data value to -1 - argmax.fill_(static_cast(-1)); - } - // calculate task one dimension - cnrtDim3_t k1_dim; - cnrtFunctionType_t k1_type; - kernelPtsIdxOfVoxelsPolicyFunc(boxes_num, &k1_dim, &k1_type); - cnrtDim3_t k2_dim; - cnrtFunctionType_t k2_type; - kernelRoiawarePool3dForwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k2_dim, - &k2_type); - // get compute queue - auto queue = torch_mlu::getCurQueue(); - // get ptr of tensors - auto rois_impl = torch_mlu::getMluTensorImpl(rois); + // get compute handle + auto handle = mluOpGetCurrentHandle(); + + auto rois_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format()); + auto pts_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(pts, pts.suggest_memory_format()); + auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + pts_feature, pts_feature.suggest_memory_format()); + auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + argmax, argmax.suggest_memory_format()); + auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format()); + auto 
pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + pooled_features, pooled_features.suggest_memory_format()); + + MluOpTensorDescriptor rois_desc, pts_desc, pts_feature_desc, argmax_desc, + pts_idx_of_voxels_desc, pooled_features_desc; + rois_desc.set(rois_contiguous); + pts_desc.set(pts_contiguous); + pts_feature_desc.set(pts_feature_contiguous); + argmax_desc.set(argmax_contiguous); + pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous); + pooled_features_desc.set(pooled_features_contiguous); + + // allocate extra space for workspace + size_t workspace_size = 0; + mluOpGetRoiawarePool3dForwardWorkspaceSize( + handle, rois_desc.desc(), pts_desc.desc(), pts_feature_desc.desc(), + &workspace_size); + + auto workspace = at::empty(workspace_size, rois.options().dtype(at::kByte)); + auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); + auto workspace_ptr = workspace_impl->cnnlMalloc(); + + auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous); + auto pts_impl = torch_mlu::getMluTensorImpl(pts_contiguous); + auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous); + auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous); + auto pts_idx_of_voxels_impl = + torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous); + auto pooled_features_impl = + torch_mlu::getMluTensorImpl(pooled_features_contiguous); + auto rois_ptr = rois_impl->cnnlMalloc(); - // transpose points [pts_num, 3] -> [3, pts_num] - auto pts_ = pts.permute({1, 0}).contiguous(); - auto pts_impl = torch_mlu::getMluTensorImpl(pts_); auto pts_ptr = pts_impl->cnnlMalloc(); - // transpose points_features [pts_num, channels] -> [channels, pts_num] - auto pts_feature_ = pts_feature.permute({1, 0}).contiguous(); - auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_); auto pts_feature_ptr = pts_feature_impl->cnnlMalloc(); - auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels); + auto argmax_ptr = 
argmax_impl->cnnlMalloc(); auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc(); - auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features); auto pooled_features_ptr = pooled_features_impl->cnnlMalloc(); - auto argmax_impl = torch_mlu::getMluTensorImpl(argmax); - auto argmax_ptr = argmax_impl->cnnlMalloc(); - // get compute dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(rois.dtype()); - // launch kernel PtsIdxOfVoxels - CNLOG(INFO) << "Launch Kernel MLUKernel PtsIdxOfVoxels<<<" << k1_dim.x << ", " - << k1_dim.y << ", " << k1_dim.z << ">>>"; - KernelPtsIdxOfVoxels(k1_dim, k1_type, queue, data_type, pool_method, - boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, - out_z, rois_ptr, pts_ptr, (int *)pts_idx_of_voxels_ptr); - // launch kernel RoiawarePool3dForward - CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dForward<<<" << k2_dim.x - << ", " << k2_dim.y << ", " << k2_dim.z << ">>>"; - KernelRoiawarePool3dForward( - k2_dim, k2_type, queue, data_type, pool_method, boxes_num, pts_num, - channels, max_pts_each_voxel, out_x, out_y, out_z, pts_feature_ptr, - (int *)pts_idx_of_voxels_ptr, pooled_features_ptr, (int *)argmax_ptr); + + CNLOG(INFO) << "Call mluOpRoiawarePool3dForward()."; + mluOpRoiawarePool3dForward( + handle, pool_method, boxes_num, pts_num, channels, rois_desc.desc(), + rois_ptr, pts_desc.desc(), pts_ptr, pts_feature_desc.desc(), + pts_feature_ptr, workspace_ptr, workspace_size, max_pts_each_voxel, out_x, + out_y, out_z, argmax_desc.desc(), argmax_ptr, + pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr, + pooled_features_desc.desc(), pooled_features_ptr); } void roiaware_pool3d_forward_mlu(int boxes_num, int pts_num, int channels, @@ -245,136 +100,46 @@ void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels, REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MLU, roiaware_pool3d_forward_mlu); -void KernelRoiawarePool3dBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, 
cnrtQueue_t queue, - const cnrtDataType_t d_type, const int pool_method, const int boxes_num, - const int out_x, const int out_y, const int out_z, const int channels, - const int max_pts_each_voxel, const int *pts_idx_of_voxels, - const int *argmax, const void *grad_out, void *grad_in); - -static void kernelRoiawarePool3dBackwardPolicyFunc( - const int boxes_num, const int out_x, const int out_y, const int out_z, - cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { - unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - *k_type = CNRT_FUNC_TYPE_UNION1; - k_dim->x = core_num; - const int voxels_num = boxes_num * out_x * out_y * out_z; - unsigned int use_cluster = (voxels_num + core_num - 1) / core_num; - k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster; - k_dim->z = 1; -} - void RoiawarePool3dBackwardMLUKernelLauncher( int pool_method, int boxes_num, int out_x, int out_y, int out_z, int channels, int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax, const Tensor grad_out, Tensor grad_in) { - // check datatype - TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt), - "pts_idx_of_voxels type should be Int, got ", - pts_idx_of_voxels.scalar_type(), "."); - TORCH_CHECK((argmax.scalar_type() == at::kInt), - "argmax type should be Int, got ", argmax.scalar_type(), "."); - TORCH_CHECK((grad_out.scalar_type() == at::kFloat || - grad_out.scalar_type() == at::kHalf), - "grad_out type should be Float or Half, got ", - grad_out.scalar_type(), "."); - TORCH_CHECK((grad_out.scalar_type() == grad_in.scalar_type()), - "data types of grad_out, grad_in, should be the same, ", - "but now grad_out type is ", grad_out.scalar_type(), - ", grad_in type is ", grad_in.scalar_type(), "."); - // check dim - TORCH_CHECK(pts_idx_of_voxels.dim() == 5, - "pts_idx_of_voxels should be a 5D tensor, got ", - pts_idx_of_voxels.dim(), "D."); - 
TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ", - argmax.dim(), "D."); - TORCH_CHECK(grad_out.dim() == 5, "grad_out should be a 5D tensor, got ", - grad_out.dim(), "D."); - TORCH_CHECK(grad_in.dim() == 2, "grad_in should be a 2D tensor, got ", - grad_in.dim(), "D."); - // check shape - TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) && - (pts_idx_of_voxels.size(1) == out_x) && - (pts_idx_of_voxels.size(2) == out_y) && - (pts_idx_of_voxels.size(3) == out_z) && - (pts_idx_of_voxels.size(4) == max_pts_each_voxel)), - "the dimensions of pts_idx_of_voxels should be (boxes_num, " - "out_x, out_y, out_z, max_pts_each_voxel), ", - "but got (", pts_idx_of_voxels.size(0), ",", - pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",", - pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ")."); - TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) && - (argmax.size(2) == out_y) && (argmax.size(3) == out_z) && - (argmax.size(4) == channels)), - "the dimensions of argmax should be (boxes_num, out_x, out_y, " - "out_z, channels), ", - "but got (", argmax.size(0), ",", argmax.size(1), ",", - argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ")."); - TORCH_CHECK(((grad_out.size(0) == boxes_num) && (grad_out.size(1) == out_x) && - (grad_out.size(2) == out_y) && (grad_out.size(3) == out_z) && - (grad_out.size(4) == channels)), - "the dimensions of grad_out should be (boxes_num, out_x, " - "out_y, out_z, channels), ", - "but got (", grad_out.size(0), ",", grad_out.size(1), ",", - grad_out.size(2), ",", grad_out.size(3), ",", grad_out.size(4), - ")."); - TORCH_CHECK((grad_in.size(1) == channels), - "the 1st dimensions of grad_in should be channels, ", "but got ", - grad_in.size(1), "."); - // check other params : pool_mothod - TORCH_CHECK(((pool_method == 0) || (pool_method == 1)), - "the num of pool_method should be 0(max) or 1(avg), ", "but got ", - pool_method, "."); - // check large tensor - const size_t 
max_input_size = 2147483648; - TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size, - "pts_idx_of_voxels element num should be less than 2^31, got ", - pts_idx_of_voxels.numel(), "."); - TORCH_CHECK(argmax.numel() < max_input_size, - "argmax element num should be less than 2^31, got ", - argmax.numel(), "."); - TORCH_CHECK(grad_out.numel() < max_input_size, - "grad_out element num should be less than 2^31, got ", - grad_out.numel(), "."); - TORCH_CHECK(grad_in.numel() < max_input_size, - "grad_in element num should be less than 2^31, got ", - grad_in.numel(), "."); - // check zero element - TORCH_CHECK(pts_idx_of_voxels.numel() != 0, - "pts_idx_of_voxels.numel() should not be zero, got ", - pts_idx_of_voxels.numel()); - TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ", - argmax.numel()); - TORCH_CHECK(grad_out.numel() != 0, - "grad_out.numel() should not be zero, got ", grad_out.numel()); - TORCH_CHECK(grad_in.numel() != 0, "grad_in.numel() should not be zero, got ", - grad_in.numel()); - // calculate task one dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - kernelRoiawarePool3dBackwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k_dim, - &k_type); - // get compute queue - auto queue = torch_mlu::getCurQueue(); - // transpose points_features [pts_num, channels] -> [channels, pts_num] - auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels); + // get compute handle + auto handle = mluOpGetCurrentHandle(); + auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format()); + auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + argmax, argmax.suggest_memory_format()); + auto grad_out_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + grad_out, grad_out.suggest_memory_format()); + auto grad_in_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + grad_in, grad_in.suggest_memory_format()); + + MluOpTensorDescriptor 
pts_idx_of_voxels_desc, argmax_desc, grad_out_desc, + grad_in_desc; + + pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous); + argmax_desc.set(argmax_contiguous); + grad_out_desc.set(grad_out_contiguous); + grad_in_desc.set(grad_in_contiguous); + + auto pts_idx_of_voxels_impl = + torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous); + auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous); + auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out_contiguous); + auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in_contiguous); + auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc(); - auto argmax_impl = torch_mlu::getMluTensorImpl(argmax); auto argmax_ptr = argmax_impl->cnnlMalloc(); - auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out); auto grad_out_ptr = grad_out_impl->cnnlMalloc(); - auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in); auto grad_in_ptr = grad_in_impl->cnnlMalloc(); - // get compute dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(grad_out.dtype()); - // launch kernel RoiawarePool3dForward - CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dBackward<<<" << k_dim.x - << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - KernelRoiawarePool3dBackward(k_dim, k_type, queue, data_type, pool_method, - boxes_num, out_x, out_y, out_z, channels, - max_pts_each_voxel, (int *)pts_idx_of_voxels_ptr, - (int *)argmax_ptr, grad_out_ptr, grad_in_ptr); + + CNLOG(INFO) << "Call mluOpRoiawarePool3dBackward()."; + mluOpRoiawarePool3dBackward( + handle, pool_method, boxes_num, out_x, out_y, out_z, channels, + max_pts_each_voxel, pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr, + argmax_desc.desc(), argmax_ptr, grad_out_desc.desc(), grad_out_ptr, + grad_in_desc.desc(), grad_in_ptr); } void roiaware_pool3d_backward_mlu(int boxes_num, int out_x, int out_y,