diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 384ef3f41b..a1ac2ca8e3 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -45,7 +45,7 @@ We implement common ops used in detection, segmentation, etc. | RoIAlignRotated | √ | √ | √ | | | RiRoIAlignRotated | | √ | | | | RoIAlign | √ | √ | √ | | -| RoIAwarePool3d | | √ | | | +| RoIAwarePool3d | | √ | √ | | | SAConv2d | | √ | | | | SigmoidFocalLoss | | √ | √ | | | SoftmaxFocalLoss | | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index f8a9e864e6..c199356890 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -45,7 +45,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | RoIAlignRotated | √ | √ | √ | | | RiRoIAlignRotated | | √ | | | | RoIAlign | √ | √ | √ | | -| RoIAwarePool3d | | √ | | | +| RoIAwarePool3d | | √ | √ | | | SAConv2d | | √ | | | | SigmoidFocalLoss | | √ | √ | | | SoftmaxFocalLoss | | √ | | | diff --git a/mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu new file mode 100644 index 0000000000..4c1edf0bf5 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/roiaware_pool3d_mlu_kernel.mlu @@ -0,0 +1,747 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ *************************************************************************/
+
+#include "common_mlu_helper.hpp"
+
+#define ROI_OFFSET 7
+#define FLOAT_NRAM_BUFFER_NUM 14
+#define HALF_NRAM_BUFFER_NUM 25
+#define ALIGN_NUM 64
+
+__nram__ char data_nram[MAX_NRAM_SIZE];
+
+template <typename T>
+__mlu_global__ void MLUUnion1KernelPtsIdxOfVoxels(
+    const int pool_method, const int boxes_num, const int pts_num,
+    const int max_pts_each_voxel, const int out_x, const int out_y,
+    const int out_z, const T *rois, const T *pts, int *pts_idx_of_voxels) {
+  // params (T)rois: (boxes_num, 7)
+  // params (T)pts: (3, pts_num)
+  // params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
+  // max_pts_each_voxel)
+
+  // make sure that memcore is not used
+  if (coreId == 0x80) {
+    return;
+  }
+  int nram_pts_num = 0;
+  if (sizeof(T) == sizeof(float)) {
+    nram_pts_num = PAD_DOWN(
+        (MAX_NRAM_SIZE / sizeof(float) / FLOAT_NRAM_BUFFER_NUM), ALIGN_NUM);
+  } else {
+    nram_pts_num = PAD_DOWN(
+        (MAX_NRAM_SIZE / sizeof(half) / HALF_NRAM_BUFFER_NUM), ALIGN_NUM);
+  }
+
+  char *X = NULL;
+  char *Y = NULL;
+  char *Z = NULL;
+  char *local_X = NULL;
+  char *local_Y = NULL;
+  char *local_Z = NULL;
+  char *nram_pts_in_flag = NULL;
+  float *temp_buffer1 = NULL;
+  float *temp_buffer2 = NULL;
+  float *temp_buffer3 = NULL;
+  float *temp_buffer4 = NULL;
+  float *temp_buffer5 = NULL;
+  float *nram_voxel_offset = NULL;
+  int *nram_pts_idx_seq = NULL;
+  float *fp_local_X = NULL;
+  float *fp_local_Y = NULL;
+  float *fp_local_Z = NULL;
+  float *fp_nram_pts_in_flag = NULL;
+  if (sizeof(T) == sizeof(float)) {
+    X = (char *)((float *)data_nram);
+    Y = (char *)((float *)data_nram + nram_pts_num);
+    Z = (char *)((float *)data_nram + nram_pts_num * 2);
+    local_X = (char *)((float *)data_nram + nram_pts_num * 3);
+    local_Y = (char *)((float *)data_nram + nram_pts_num * 4);
+    local_Z = (char *)((float *)data_nram + nram_pts_num * 5);
+    nram_pts_in_flag = (char *)((float *)data_nram + nram_pts_num * 6);
+    temp_buffer1 = (float *)data_nram + nram_pts_num * 7;
+    temp_buffer2 = (float *)data_nram + nram_pts_num * 8;
+    temp_buffer3 = (float *)data_nram + nram_pts_num * 9;
+    temp_buffer4 = (float *)data_nram + nram_pts_num * 10;
+    temp_buffer5 = (float *)data_nram + nram_pts_num * 11;
+    nram_voxel_offset = (float *)data_nram + nram_pts_num * 12;
+    nram_pts_idx_seq = (int *)((float *)data_nram + nram_pts_num * 13);
+    fp_local_X = (float *)local_X;
+    fp_local_Y = (float *)local_Y;
+    fp_local_Z = (float *)local_Z;
+    fp_nram_pts_in_flag = (float *)nram_pts_in_flag;
+  } else {
+    X = (char *)((half *)data_nram);
+    Y = (char *)((half *)data_nram + nram_pts_num);
+    Z = (char *)((half *)data_nram + nram_pts_num * 2);
+    local_X = (char *)((half *)data_nram + nram_pts_num * 4);
+    local_Y = (char *)((half *)data_nram + nram_pts_num * 6);
+    local_Z = (char *)((half *)data_nram + nram_pts_num * 8);
+    nram_pts_in_flag = (char *)((half *)data_nram + nram_pts_num * 10);
+    temp_buffer1 = (float *)((half *)data_nram + nram_pts_num * 11);
+    temp_buffer2 = (float *)((half *)data_nram + nram_pts_num * 13);
+    temp_buffer3 = (float *)((half *)data_nram + nram_pts_num * 15);
+    temp_buffer4 = (float *)((half *)data_nram + nram_pts_num * 17);
+    temp_buffer5 = (float *)((half *)data_nram + nram_pts_num * 19);
+    nram_voxel_offset = (float *)((half *)data_nram + nram_pts_num * 21);
+    nram_pts_idx_seq = (int *)((half *)data_nram + nram_pts_num * 23);
+    fp_local_X = (float *)((half *)local_X - nram_pts_num);
+    fp_local_Y = (float *)((half *)local_Y - nram_pts_num);
+    fp_local_Z =
(float *)((half *)local_Z - nram_pts_num); + fp_nram_pts_in_flag = (float *)((half *)nram_pts_in_flag - nram_pts_num); + } + + for (int i = 0; i < nram_pts_num; i++) { + nram_pts_idx_seq[i] = i; + } + + int nram_pts_loop_times = pts_num / nram_pts_num; + int rem_nram_num = pts_num % nram_pts_num; + + for (int roi_index = taskId; roi_index < boxes_num; roi_index += taskDim) { + const T *cur_roi = rois + roi_index * ROI_OFFSET; + T cx = cur_roi[0]; + T cy = cur_roi[1]; + T cz = cur_roi[2]; + T dx = cur_roi[3]; + T dy = cur_roi[4]; + T dz = cur_roi[5]; + T rz = cur_roi[6]; + + T dx_2 = dx / 2.0; + T dy_2 = dy / 2.0; + T dz_2 = dz / 2.0; + + for (int loop_idx = 0; loop_idx <= nram_pts_loop_times; loop_idx++) { + int load_pts_num = + (loop_idx == nram_pts_loop_times) ? rem_nram_num : nram_pts_num; + if (load_pts_num == 0) { + break; + } + int pts_offset_cur_loop = nram_pts_num * loop_idx; + int compute_pts_num = (loop_idx == nram_pts_loop_times) + ? PAD_UP(rem_nram_num, ALIGN_NUM) + : nram_pts_num; + // load pts + __memcpy((void *)X, (T *)pts + pts_offset_cur_loop, + load_pts_num * sizeof(T), GDRAM2NRAM); + __memcpy((void *)Y, (T *)pts + pts_num + pts_offset_cur_loop, + load_pts_num * sizeof(T), GDRAM2NRAM); + __memcpy((void *)Z, (T *)pts + pts_num * 2 + pts_offset_cur_loop, + load_pts_num * sizeof(T), GDRAM2NRAM); + // fabs(local_z) + __bang_sub_scalar((T *)local_Z, (T *)Z, (T)cz, compute_pts_num); + __bang_sub_scalar((T *)temp_buffer1, (T *)Z, (T)(cz + dz_2), + compute_pts_num); + __bang_active_abs((T *)temp_buffer1, (T *)temp_buffer1, compute_pts_num); +#if __BANG_ARCH__ >= 322 + __bang_le_scalar((T *)nram_pts_in_flag, (T *)temp_buffer1, (T)(dz_2), + compute_pts_num); +#else + __bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dz_2)); + __bang_le((T *)nram_pts_in_flag, (T *)temp_buffer1, (T *)temp_buffer2, + compute_pts_num); +#endif + T cosa = std::cos(-rz); + T sina = std::sin(-rz); + __bang_sub_scalar((T *)temp_buffer3, (T *)X, (T)cx, compute_pts_num); + __bang_sub_scalar((T *)temp_buffer4, (T *)Y, (T)cy, compute_pts_num); + __bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)cosa, + compute_pts_num); + __bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)sina, + compute_pts_num); + // local_x + __bang_sub((T *)local_X, (T *)temp_buffer1, (T *)temp_buffer2, + compute_pts_num); + // fabs(local_x) + __bang_active_abs((T *)temp_buffer1, (T *)local_X, compute_pts_num); + // fabs(local_x) < dx/2 ? 1 : 0 +#if __BANG_ARCH__ >= 322 + __bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dx_2), + compute_pts_num); +#else + __bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dx_2)); + __bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2, + compute_pts_num); +#endif + __bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag, + (T *)temp_buffer1, + compute_pts_num); // flush res + + __bang_mul_scalar((T *)temp_buffer1, (T *)temp_buffer3, (T)sina, + compute_pts_num); + __bang_mul_scalar((T *)temp_buffer2, (T *)temp_buffer4, (T)cosa, + compute_pts_num); + // local_y + __bang_add((T *)local_Y, (T *)temp_buffer1, (T *)temp_buffer2, + compute_pts_num); + // fabs(local_y) + __bang_active_abs((T *)temp_buffer1, (T *)local_Y, compute_pts_num); + // fabs(local_y) < dy/2 ? 
1 : 0 +#if __BANG_ARCH__ >= 322 + __bang_lt_scalar((T *)temp_buffer1, (T *)temp_buffer1, (T)(dy_2), + compute_pts_num); +#else + __bang_write_value((void *)temp_buffer2, compute_pts_num, (T)(dy_2)); + __bang_lt((T *)temp_buffer1, (T *)temp_buffer1, (T *)temp_buffer2, + compute_pts_num); +#endif + __bang_and((T *)nram_pts_in_flag, (T *)nram_pts_in_flag, + (T *)temp_buffer1, + compute_pts_num); // flush res + T x_res = dx / out_x; + T y_res = dy / out_y; + T z_res = dz / out_z; + __bang_add_scalar((T *)local_X, (T *)local_X, (T)(dx_2), compute_pts_num); + __bang_add_scalar((T *)local_Y, (T *)local_Y, (T)(dy_2), compute_pts_num); + // local_Z do not need to add dz/2.0 + +#if (__BANG_ARCH__ >= 322) && (__BANG_ARCH__ != 372) + __bang_div((T *)local_X, (T *)local_X, (T)x_res, compute_pts_num); + __bang_div((T *)local_Y, (T *)local_Y, (T)y_res, compute_pts_num); + __bang_div((T *)local_Z, (T *)local_Z, (T)z_res, compute_pts_num); +#else + __bang_mul_scalar((T *)local_X, (T *)local_X, (T)(1 / x_res), + compute_pts_num); + __bang_mul_scalar((T *)local_Y, (T *)local_Y, (T)(1 / y_res), + compute_pts_num); + __bang_mul_scalar((T *)local_Z, (T *)local_Z, (T)(1 / z_res), + compute_pts_num); +#endif + // float = float2int + int2float, half = half2int + int2float + if (sizeof(T) == sizeof(float)) { +#if __BANG_ARCH__ >= 322 + __bang_float2int32_tz((int *)temp_buffer1, (float *)local_X, + compute_pts_num, 0); + __bang_float2int32_tz((int *)temp_buffer2, (float *)local_Y, + compute_pts_num, 0); + __bang_float2int32_tz((int *)temp_buffer3, (float *)local_Z, + compute_pts_num, 0); + __bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1, + compute_pts_num, 0); + __bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2, + compute_pts_num, 0); + __bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3, + compute_pts_num, 0); +#else + convertFloat2Int((int *)temp_buffer1, (float *)temp_buffer2, + (float *)fp_local_X, (float *)temp_buffer3, + compute_pts_num); + convertFloat2Int((int *)temp_buffer2, (float *)temp_buffer3, + (float *)fp_local_Y, (float *)temp_buffer4, + compute_pts_num); + convertFloat2Int((int *)temp_buffer3, (float *)temp_buffer4, + (float *)fp_local_Z, (float *)temp_buffer5, + compute_pts_num); + convertInt2Float((float *)fp_local_X, (float *)temp_buffer4, + (int *)temp_buffer1, (float *)temp_buffer5, + compute_pts_num); + convertInt2Float((float *)fp_local_Y, (float *)temp_buffer4, + (int *)temp_buffer2, (float *)temp_buffer5, + compute_pts_num); + convertInt2Float((float *)fp_local_Z, (float *)temp_buffer4, + (int *)temp_buffer3, (float *)temp_buffer5, + compute_pts_num); +#endif + } else { + __bang_half2float((float *)temp_buffer4, (half *)nram_pts_in_flag, + compute_pts_num); + __bang_move((void *)fp_nram_pts_in_flag, (void *)temp_buffer4, + compute_pts_num * sizeof(float)); +#if __BANG_ARCH__ >= 322 + __bang_half2int32_tz((int *)temp_buffer1, (half *)local_X, + compute_pts_num, 0); + __bang_half2int32_tz((int *)temp_buffer2, (half *)local_Y, + compute_pts_num, 0); + __bang_half2int32_tz((int *)temp_buffer3, (half *)local_Z, + compute_pts_num, 0); + __bang_int322float_rn((float *)fp_local_X, (int *)temp_buffer1, + compute_pts_num, 0); + __bang_int322float_rn((float *)fp_local_Y, (int *)temp_buffer2, + compute_pts_num, 0); + __bang_int322float_rn((float *)fp_local_Z, (int *)temp_buffer3, + compute_pts_num, 0); +#else + __bang_half2int16_tz((int16_t *)temp_buffer1, (half *)local_X, + compute_pts_num, 0); + __bang_half2int16_tz((int16_t *)temp_buffer2, (half *)local_Y, + 
                             compute_pts_num, 0);
+        __bang_half2int16_tz((int16_t *)temp_buffer3, (half *)local_Z,
+                             compute_pts_num, 0);
+        __bang_int162float((float *)fp_local_X, (int16_t *)temp_buffer1,
+                           compute_pts_num, 0);
+        __bang_int162float((float *)fp_local_Y, (int16_t *)temp_buffer2,
+                           compute_pts_num, 0);
+        __bang_int162float((float *)fp_local_Z, (int16_t *)temp_buffer3,
+                           compute_pts_num, 0);
+#endif
+      }
+      // process index >= 0
+      __bang_write_value((float *)temp_buffer4, compute_pts_num, (float)0.0f);
+      __bang_maxequal((float *)fp_local_X, (float *)fp_local_X,
+                      (float *)temp_buffer4, compute_pts_num);
+      __bang_maxequal((float *)fp_local_Y, (float *)fp_local_Y,
+                      (float *)temp_buffer4, compute_pts_num);
+      __bang_maxequal((float *)fp_local_Z, (float *)fp_local_Z,
+                      (float *)temp_buffer4, compute_pts_num);
+      // process index <= (out_x - 1)
+      __bang_write_value((float *)temp_buffer5, compute_pts_num,
+                         (float)(out_x - 1));
+      __bang_minequal((float *)fp_local_X, (float *)fp_local_X,
+                      (float *)temp_buffer5, compute_pts_num);
+      __bang_write_value((float *)temp_buffer5, compute_pts_num,
+                         (float)(out_y - 1));
+      __bang_minequal((float *)fp_local_Y, (float *)fp_local_Y,
+                      (float *)temp_buffer5, compute_pts_num);
+      __bang_write_value((float *)temp_buffer5, compute_pts_num,
+                         (float)(out_z - 1));
+      __bang_minequal((float *)fp_local_Z, (float *)fp_local_Z,
+                      (float *)temp_buffer5, compute_pts_num);
+      __bang_mul_scalar((float *)temp_buffer1, (float *)fp_local_X,
+                        (float)(out_y * out_z), compute_pts_num);
+      __bang_mul_scalar((float *)temp_buffer2, (float *)fp_local_Y,
+                        (float)out_z, compute_pts_num);
+      __bang_mul_scalar((float *)temp_buffer3, (float *)fp_local_Z, (float)1.0,
+                        compute_pts_num);
+      __bang_add((float *)nram_voxel_offset, (float *)temp_buffer1,
+                 (float *)temp_buffer2, compute_pts_num);
+      __bang_add((float *)nram_voxel_offset, (float *)nram_voxel_offset,
+                 (float *)temp_buffer3, compute_pts_num);
+      __bang_mul_scalar((float *)nram_voxel_offset, (float *)nram_voxel_offset,
+                        (float)max_pts_each_voxel, compute_pts_num);
+      if (compute_pts_num != load_pts_num) {
+        __memset_nram((float *)fp_nram_pts_in_flag + load_pts_num,
+                      compute_pts_num - load_pts_num, (float)0.0);
+      }
+      __bang_collect((float *)temp_buffer4, (float *)nram_pts_idx_seq,
+                     (float *)fp_nram_pts_in_flag, compute_pts_num);
+      int pts_num_in_cur_roi =
+          (int)__bang_count((float *)fp_nram_pts_in_flag, compute_pts_num);
+      int *pts_idx_cur_voxels =
+          (int *)pts_idx_of_voxels +
+          roi_index * out_x * out_y * out_z * max_pts_each_voxel;
+      for (int idx = 0; idx < pts_num_in_cur_roi; idx++) {
+        int cur_pts_idx = *((int *)temp_buffer4 + idx);
+        int offset = (int)(*((float *)nram_voxel_offset + cur_pts_idx));
+        int cnt = pts_idx_cur_voxels[offset];
+        if (cnt < max_pts_each_voxel - 1) {
+          pts_idx_cur_voxels[offset + cnt + 1] =
+              cur_pts_idx + loop_idx * nram_pts_num;
+          pts_idx_cur_voxels[offset]++;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+__mlu_global__ void MLUUnion1KernelRoiawarePool3dForward(
+    const int pool_method, const int boxes_num, const int pts_num,
+    const int channels, const int max_pts_each_voxel, const int out_x,
+    const int out_y, const int out_z, const T *pts_feature,
+    const int *pts_idx_of_voxels, T *pooled_features, int *argmax) {
+  // params (T)pts_feature: (channels, pts_num)
+  // params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
+  // max_pts_each_voxel) params (int)argmax: (boxes_num, out_x, out_y, out_z,
+  // channels) params (T)pooled_features: (boxes_num, out_x, out_y, out_z,
+  // channels)
+
+  // make sure that memcore is not
used + if (coreId == 0x80) { + return; + } + int align_num = NFU_ALIGN_SIZE / sizeof(T); + int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num); + int nram_channels_limit = + PAD_DOWN((MAX_NRAM_SIZE - 128 - + align_max_pts_each_voxel * (sizeof(int) + sizeof(T))) / + ((align_max_pts_each_voxel + 1) * sizeof(T) + sizeof(int)), + align_num); + int *nram_pts_idx_cur_voxel = (int *)data_nram; + // nram_pts_idx_cur_voxel [align_max_pts_each_voxel] + T *nram_max_pts_feature_tmp = + (T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel); + // nram_max_pts_feature_tmp [align_max_pts_each_voxel] + T *nram_pts_feature_in_voxel = + ((T *)nram_max_pts_feature_tmp + align_max_pts_each_voxel); + // nram_pts_feature_in_voxel [nram_channels_limit, align_max_pts_each_voxel] + T *nram_pooled_features_cur_voxel = + ((T *)nram_pts_feature_in_voxel + + nram_channels_limit * align_max_pts_each_voxel); + // nram_pooled_features_cur_voxel [nram_channels_limit] + int *nram_argmax_cur_voxel = + (int *)((T *)nram_pooled_features_cur_voxel + nram_channels_limit); + // nram_argmax_cur_voxel [nram_channels_limit] + char *one_pooled_feature = + (char *)((int *)nram_argmax_cur_voxel + nram_channels_limit); + // one_pooled_feature [128] + int channels_loop_times = channels / nram_channels_limit; + int rem_channels = channels % nram_channels_limit; + for (int voxel_index = taskId; + voxel_index < boxes_num * out_x * out_y * out_z; + voxel_index += taskDim) { + int *pts_idx_cur_voxels = + (int *)pts_idx_of_voxels + voxel_index * max_pts_each_voxel; + __memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxels, + max_pts_each_voxel * sizeof(int), GDRAM2NRAM); + int pts_num_cur_voxel = nram_pts_idx_cur_voxel[0]; + if (pts_num_cur_voxel == 0) { + continue; + } + for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times; + channels_loop_idx++) { + int actual_channels_num = (channels_loop_idx == channels_loop_times) + ? rem_channels + : nram_channels_limit; + if (actual_channels_num == 0) { + break; + } + int channels_offset = nram_channels_limit * channels_loop_idx; + +#if ((__BANG_ARCH__ >= 200) && (__BANG_ARCH__ < 300)) + int compute_channels_num = (channels_loop_idx == channels_loop_times) + ? PAD_UP(rem_channels, align_num) + : nram_channels_limit; + if (pool_method == 0) { + __bang_write_value((void *)nram_pts_feature_in_voxel, + compute_channels_num * align_max_pts_each_voxel, + (T)-INFINITY); + } +#endif + + T *pts_feature_cur_loop = (T *)pts_feature + channels_offset * pts_num; + for (int idx = 0; idx < pts_num_cur_voxel; idx++) { + __memcpy((T *)nram_pts_feature_in_voxel + idx, + (T *)pts_feature_cur_loop + nram_pts_idx_cur_voxel[idx + 1], + sizeof(T), GDRAM2NRAM, align_max_pts_each_voxel * sizeof(T), + pts_num * sizeof(T), actual_channels_num - 1); + } + for (int channel_idx = 0; channel_idx < actual_channels_num; + channel_idx++) { + if (pool_method == 0) { +#if __BANG_ARCH__ >= 322 + __bang_argmax((T *)one_pooled_feature, + (T *)nram_pts_feature_in_voxel + + channel_idx * align_max_pts_each_voxel, + pts_num_cur_voxel); + T max_val = ((T *)one_pooled_feature)[0]; + int max_idx = (int)(*(uint32_t *)((T *)one_pooled_feature + 1)); + nram_pooled_features_cur_voxel[channel_idx] = + (max_val == -INFINITY) ? 0 : max_val; + nram_argmax_cur_voxel[channel_idx] = + (max_val == -INFINITY) ? 
-1 : nram_pts_idx_cur_voxel[max_idx + 1];
+#else
+          // __bang_max need align num on mlu200 series
+          if (sizeof(T) == sizeof(float)) {
+            __bang_max((float *)one_pooled_feature,
+                       (float *)nram_pts_feature_in_voxel +
+                           channel_idx * align_max_pts_each_voxel,
+                       align_max_pts_each_voxel);
+            float max_val = ((float *)one_pooled_feature)[0];
+            __bang_write_value((void *)nram_max_pts_feature_tmp,
+                               align_max_pts_each_voxel, (float)max_val);
+            __bang_eq((float *)nram_max_pts_feature_tmp,
+                      (float *)nram_pts_feature_in_voxel +
+                          channel_idx * align_max_pts_each_voxel,
+                      (float *)nram_max_pts_feature_tmp,
+                      align_max_pts_each_voxel);
+            int max_idx = (int)__bang_findfirst1(
+                (float *)nram_max_pts_feature_tmp, align_max_pts_each_voxel);
+            nram_pooled_features_cur_voxel[channel_idx] =
+                (max_val == -INFINITY) ? 0 : max_val;
+            nram_argmax_cur_voxel[channel_idx] =
+                (max_val == -INFINITY) ? -1
+                                       : nram_pts_idx_cur_voxel[max_idx + 1];
+          } else {
+            int max_idx = -1;
+            float max_val = -INFINITY;
+            for (int k = 0; k < pts_num_cur_voxel; k++) {
+              float pts_feature_cur_channel = __half2float_rd(
+                  *((half *)nram_pts_feature_in_voxel +
+                    channel_idx * align_max_pts_each_voxel + k));
+              if (pts_feature_cur_channel > max_val) {
+                max_val = pts_feature_cur_channel;
+                max_idx = k;
+              }
+            }
+            nram_pooled_features_cur_voxel[channel_idx] =
+                (max_idx == -1) ? 0 : max_val;
+            nram_argmax_cur_voxel[channel_idx] =
+                (max_idx == -1) ? -1 : nram_pts_idx_cur_voxel[max_idx + 1];
+          }
+#endif
+        } else if (pool_method == 1) {
+          float sum_val_cur_channel = 0;
+          for (int k = 0; k < pts_num_cur_voxel; k++) {
+            sum_val_cur_channel += static_cast<float>(
+                ((T *)nram_pts_feature_in_voxel)[channel_idx *
+                                                     align_max_pts_each_voxel +
+                                                 k]);
+          }
+          nram_pooled_features_cur_voxel[channel_idx] =
+              (T)(sum_val_cur_channel / pts_num_cur_voxel);
+        }
+      }
+      // store
+      __memcpy((T *)pooled_features + voxel_index * channels + channels_offset,
+               (void *)nram_pooled_features_cur_voxel,
+               actual_channels_num * sizeof(T), NRAM2GDRAM);
+      if (pool_method == 0) {
+        __memcpy((int *)argmax + voxel_index * channels + channels_offset,
+                 (void *)nram_argmax_cur_voxel,
+                 actual_channels_num * sizeof(int), NRAM2GDRAM);
+      }
+    }
+  }
+}
+
+void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
+                          cnrtQueue_t queue, const cnrtDataType_t d_type,
+                          const int pool_method, const int boxes_num,
+                          const int pts_num, const int max_pts_each_voxel,
+                          const int out_x, const int out_y, const int out_z,
+                          const void *rois, const void *pts,
+                          int *pts_idx_of_voxels) {
+  switch (d_type) {
+    case CNRT_FLOAT32: {
+      MLUUnion1KernelPtsIdxOfVoxels<float><<<k_dim, k_type, queue>>>(
+          pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
+          out_z, (float *)rois, (float *)pts, (int *)pts_idx_of_voxels);
+    }; break;
+    case CNRT_FLOAT16: {
+      MLUUnion1KernelPtsIdxOfVoxels<half><<<k_dim, k_type, queue>>>(
+          pool_method, boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
+          out_z, (half *)rois, (half *)pts, (int *)pts_idx_of_voxels);
+    }; break;
+    default: {
+      break;
+    }
+  }
+}
+
+void KernelRoiawarePool3dForward(
+    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+    const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
+    const int pts_num, const int channels, const int max_pts_each_voxel,
+    const int out_x, const int out_y, const int out_z, const void *pts_feature,
+    const int *pts_idx_of_voxels, void *pooled_features, int *argmax) {
+  switch (d_type) {
+    case CNRT_FLOAT32: {
+      MLUUnion1KernelRoiawarePool3dForward<float><<<k_dim, k_type, queue>>>(
+          pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
+          out_y, out_z, (float *)pts_feature, (int *)pts_idx_of_voxels,
+          (float *)pooled_features, (int *)argmax);
+    }; break;
+    case CNRT_FLOAT16: {
+      MLUUnion1KernelRoiawarePool3dForward<half><<<k_dim, k_type, queue>>>(
+          pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x,
+          out_y, out_z, (half *)pts_feature, (int *)pts_idx_of_voxels,
+          (half *)pooled_features, (int *)argmax);
+    }; break;
+    default: {
+      break;
+    }
+  }
+}
+
+template <typename T>
+__mlu_global__ void MLUUnion1KernelRoiawareMaxPool3dBackward(
+    const int boxes_num, const int out_x, const int out_y, const int out_z,
+    const int channels, const int *argmax, const T *grad_out, T *grad_in) {
+  // params (int)argmax: (boxes_num, out_x, out_y, out_z, channels)
+  // params (T)grad_out: (boxes_num, out_x, out_y, out_z, channels)
+  // params (T)grad_in: (pts_num, channels)
+
+  // make sure that memcore is not used
+  if (coreId == 0x80) {
+    return;
+  }
+  int nram_channels_limit =
+      (MAX_NRAM_SIZE - sizeof(T) * 1) / (sizeof(T) + sizeof(int));
+  int *nram_argmax_cur_loop = (int *)data_nram;
+  // nram_argmax_cur_loop [nram_channels_limit]
+  T *nram_grad_out_cur_loop =
+      (T *)((int *)nram_argmax_cur_loop + nram_channels_limit);
+  // nram_grad_out_cur_loop [nram_channels_limit]
+  T *nram_grad_in_cur_channel =
+      (T *)nram_grad_out_cur_loop + nram_channels_limit;
+  // nram_grad_in_cur_channel [1]
+  int channels_loop_times = channels / nram_channels_limit;
+  int rem_channels = channels % nram_channels_limit;
+  int voxels_num = boxes_num * out_x * out_y * out_z;
+
+  for (int voxel_index = taskId; voxel_index < voxels_num;
+       voxel_index += taskDim) {
+    const int *argmax_cur_voxel = argmax + voxel_index * channels;
+    const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
+
+    for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
+         channels_loop_idx++) {
+      int actual_channels_num = (channels_loop_idx == channels_loop_times)
+                                    ? rem_channels
+                                    : nram_channels_limit;
+      if (actual_channels_num == 0) {
+        break;
+      }
+      const int *argmax_cur_loop =
+          argmax_cur_voxel + nram_channels_limit * channels_loop_idx;
+      const T *grad_out_cur_loop =
+          grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
+      __memcpy((void *)nram_argmax_cur_loop, (void *)argmax_cur_loop,
+               actual_channels_num * sizeof(int), GDRAM2NRAM);
+      __memcpy((void *)nram_grad_out_cur_loop, (void *)grad_out_cur_loop,
+               actual_channels_num * sizeof(T), GDRAM2NRAM);
+
+      for (int channel_idx = 0; channel_idx < actual_channels_num;
+           channel_idx++) {
+        int *nram_argmax_cur_channel = nram_argmax_cur_loop + channel_idx;
+        T *nram_grad_out_cur_channel = nram_grad_out_cur_loop + channel_idx;
+        if (nram_argmax_cur_channel[0] == -1) {
+          continue;
+        }
+        T *grad_in_cur_channel =
+            grad_in + nram_argmax_cur_channel[0] * channels +
+            nram_channels_limit * channels_loop_idx + channel_idx;
+        __bang_atomic_add((T *)nram_grad_in_cur_channel,
+                          (T *)grad_in_cur_channel,
+                          (T *)(nram_grad_out_cur_channel), 1);
+      }
+    }
+  }
+}
+
+template <typename T>
+__mlu_global__ void MLUUnion1KernelRoiawareAvgPool3dBackward(
+    const int boxes_num, const int out_x, const int out_y, const int out_z,
+    const int channels, const int max_pts_each_voxel,
+    const int *pts_idx_of_voxels, const T *grad_out, T *grad_in) {
+  // params (int)pts_idx_of_voxels: (boxes_num, out_x, out_y, out_z,
+  // max_pts_each_voxel) params (T)grad_out: (boxes_num, out_x, out_y, out_z,
+  // channels) params (T)grad_in: (pts_num, channels)
+
+  // make sure that memcore is not used
+  if (coreId == 0x80) {
+    return;
+  }
+  int align_num = NFU_ALIGN_SIZE / sizeof(T);
+  int align_max_pts_each_voxel = PAD_UP(max_pts_each_voxel, align_num);
+  int nram_channels_limit = PAD_DOWN(
+      (MAX_NRAM_SIZE - align_max_pts_each_voxel * sizeof(int)) / 2 / sizeof(T),
+      align_num);
+  int *nram_pts_idx_cur_voxel = (int *)data_nram;
+  // nram_pts_idx_cur_voxel [align_max_pts_each_voxel]
+  T *nram_grad_out_cur_loop =
+      (T *)((int *)nram_pts_idx_cur_voxel + align_max_pts_each_voxel);
+  // nram_grad_out_cur_loop [nram_channels_limit]
+  T *nram_grad_in_cur_loop = (T *)nram_grad_out_cur_loop + nram_channels_limit;
+  // nram_grad_in_cur_loop [nram_channels_limit]
+  int channels_loop_times = channels / nram_channels_limit;
+  int rem_channels = channels % nram_channels_limit;
+  int voxels_num = boxes_num * out_x * out_y * out_z;
+
+  for (int voxel_index = taskId; voxel_index < voxels_num;
+       voxel_index += taskDim) {
+    const T *grad_out_cur_voxel = grad_out + voxel_index * channels;
+    const int *pts_idx_cur_voxel =
+        pts_idx_of_voxels + voxel_index * max_pts_each_voxel;
+    __memcpy((void *)nram_pts_idx_cur_voxel, (void *)pts_idx_cur_voxel,
+             max_pts_each_voxel * sizeof(int), GDRAM2NRAM);
+    int total_pts_of_voxel = nram_pts_idx_cur_voxel[0];
+    if (total_pts_of_voxel <= 0) {
+      continue;
+    }
+    float cur_grad = 1.0 / ((float)total_pts_of_voxel);
+
+    for (int channels_loop_idx = 0; channels_loop_idx <= channels_loop_times;
+         channels_loop_idx++) {
+      int actual_channels_num = (channels_loop_idx == channels_loop_times)
+                                    ? rem_channels
+                                    : nram_channels_limit;
+      if (actual_channels_num == 0) {
+        break;
+      }
+      const T *grad_out_cur_loop =
+          grad_out_cur_voxel + nram_channels_limit * channels_loop_idx;
+      __memcpy((void *)nram_grad_in_cur_loop, (void *)grad_out_cur_loop,
+               actual_channels_num * sizeof(T), GDRAM2NRAM);
+
+      int align_actual_channels_num = PAD_UP(actual_channels_num, align_num);
+
+      if (sizeof(T) == sizeof(half)) {
+        __bang_half2float((float *)nram_grad_out_cur_loop,
+                          (half *)nram_grad_in_cur_loop,
+                          align_actual_channels_num);
+        __bang_mul_scalar((float *)nram_grad_out_cur_loop,
+                          (float *)nram_grad_out_cur_loop, (float)cur_grad,
+                          align_actual_channels_num);
+        convertFloat2half((half *)nram_grad_out_cur_loop,
+                          (float *)nram_grad_out_cur_loop,
+                          align_actual_channels_num);
+      } else {
+        __bang_mul_scalar((float *)nram_grad_out_cur_loop,
+                          (float *)nram_grad_in_cur_loop, (float)cur_grad,
+                          align_actual_channels_num);
+      }
+      for (int k = 1; k <= total_pts_of_voxel; k++) {
+        T *grad_in_cur_loop = grad_in + nram_pts_idx_cur_voxel[k] * channels +
+                              nram_channels_limit * channels_loop_idx;
+        __bang_atomic_add((T *)nram_grad_in_cur_loop, (T *)grad_in_cur_loop,
+                          (T *)nram_grad_out_cur_loop, actual_channels_num);
+      }
+    }
+  }
+}
+
+void KernelRoiawarePool3dBackward(
+    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
+    const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
+    const int out_x, const int out_y, const int out_z, const int channels,
+    const int max_pts_each_voxel, const int *pts_idx_of_voxels,
+    const int *argmax, const void *grad_out, void *grad_in) {
+  if (pool_method == 0) {
+    switch (d_type) {
+      case CNRT_FLOAT32: {
+        MLUUnion1KernelRoiawareMaxPool3dBackward<float>
+            <<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z,
+                                       channels, (int *)argmax,
+                                       (float *)grad_out, (float *)grad_in);
+      }; break;
+      case CNRT_FLOAT16: {
+        MLUUnion1KernelRoiawareMaxPool3dBackward<half>
+            <<<k_dim, k_type, queue>>>(boxes_num, out_x, out_y, out_z,
+                                       channels, (int *)argmax,
+                                       (half *)grad_out, (half *)grad_in);
+      }; break;
+      default: {
+        break;
+      }
+    }
+  } else {
+    switch (d_type) {
+      case CNRT_FLOAT32: {
+        MLUUnion1KernelRoiawareAvgPool3dBackward<float>
+            <<<k_dim, k_type, queue>>>(
+                boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
+                (int *)pts_idx_of_voxels, (float *)grad_out, (float *)grad_in);
+      }; break;
+      case CNRT_FLOAT16: {
+        MLUUnion1KernelRoiawareAvgPool3dBackward<half>
+            <<<k_dim, k_type, queue>>>(
+                boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
+                (int *)pts_idx_of_voxels, (half *)grad_out, (half *)grad_in);
+      }; break;
+      default: {
+        break;
+      }
+    }
+  }
+}
diff --git a/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp
new file mode 100644
index 0000000000..62cb2dc62e
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp
@@ -0,0 +1,399 @@
+/*************************************************************************
+ * Copyright (C) 2022 by Cambricon.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/ +#include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" + +void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const cnrtDataType_t d_type, + const int pool_method, const int boxes_num, + const int pts_num, const int max_pts_each_voxel, + const int out_x, const int out_y, const int out_z, + const void *rois, const void *pts, + int *pts_idx_of_voxels); + +void KernelRoiawarePool3dForward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const int pool_method, const int boxes_num, + const int pts_num, const int channels, const int max_pts_each_voxel, + const int out_x, const int out_y, const int out_z, const void *pts_feature, + const int *pts_idx_of_voxels, void *pooled_features, int *argmax); + +// policy function +static void kernelPtsIdxOfVoxelsPolicyFunc(const int boxes_num, + cnrtDim3_t *k_dim, + cnrtFunctionType_t *k_type) { + unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + *k_type = CNRT_FUNC_TYPE_UNION1; + k_dim->x = core_num; + unsigned int use_cluster = (boxes_num + core_num - 1) / core_num; + k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster; + k_dim->z = 1; +} + +static void kernelRoiawarePool3dForwardPolicyFunc( + const int boxes_num, const int out_x, const int out_y, const int out_z, + cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { + unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + *k_type = CNRT_FUNC_TYPE_UNION1; + k_dim->x = core_num; + const int voxels_num = boxes_num * out_x * out_y * out_z; + unsigned int use_cluster = (voxels_num + core_num - 1) / core_num; + k_dim->y = use_cluster > cluster_num ? 
cluster_num : use_cluster;
+  k_dim->z = 1;
+}
+
+void RoiawarePool3dForwardMLUKernelLauncher(
+    const int pool_method, const int boxes_num, const int pts_num,
+    const int channels, const int max_pts_each_voxel, const int out_x,
+    const int out_y, const int out_z, const Tensor rois, const Tensor pts,
+    const Tensor pts_feature, Tensor pts_idx_of_voxels, Tensor pooled_features,
+    Tensor argmax) {
+  // check datatype
+  TORCH_CHECK(((pts.scalar_type() == rois.scalar_type()) &&
+               (pts_feature.scalar_type() == rois.scalar_type()) &&
+               (pooled_features.scalar_type() == rois.scalar_type())),
+              "data types of rois, pts, pts_feature and pooled_features "
+              "should be the same, ",
+              "but now rois type is ", rois.scalar_type(), ", pts type is ",
+              pts.scalar_type(), ", pts_feature type is ",
+              pts_feature.scalar_type(), ", pooled_features type is ",
+              pooled_features.scalar_type(), ".");
+  TORCH_CHECK(
+      (rois.scalar_type() == at::kFloat || rois.scalar_type() == at::kHalf),
+      "rois type should be Float or Half, got ", rois.scalar_type(), ".");
+  TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
+              "pts_idx_of_voxels type should be Int, got ",
+              pts_idx_of_voxels.scalar_type(), ".");
+  // check dim
+  TORCH_CHECK(rois.dim() == 2, "rois should be a 2D tensor, got ", rois.dim(),
+              "D.");
+  TORCH_CHECK(pts.dim() == 2, "pts should be a 2D tensor, got ", pts.dim(),
+              "D.");
+  TORCH_CHECK(pts_feature.dim() == 2, "pts_feature should be a 2D tensor, got ",
+              pts_feature.dim(), "D.");
+  TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
+              "pts_idx_of_voxels should be a 5D tensor, got ",
+              pts_idx_of_voxels.dim(), "D.");
+  TORCH_CHECK(pooled_features.dim() == 5,
+              "pooled_features should be a 5D tensor, got ",
+              pooled_features.dim(), "D.");
+  // check shape
+  TORCH_CHECK(((rois.size(0) == boxes_num) && (rois.size(1) == 7)),
+              "the dimensions of rois should be (boxes_num, 7), ", "but got (",
+              rois.size(0), ", ", rois.size(1), ").");
+  TORCH_CHECK(((pts.size(0) == pts_num) && (pts.size(1) == 3)),
+              "the dimensions of pts should be (pts_num, 3), ", "but got (",
+              pts.size(0), ",", pts.size(1), ").");
+  TORCH_CHECK(
+      ((pts_feature.size(0) == pts_num) && (pts_feature.size(1) == channels)),
+      "the dimensions of pts_feature should be (pts_num, channels), ",
+      "but got (", pts_feature.size(0), ",", pts_feature.size(1), ").");
+  TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
+               (pts_idx_of_voxels.size(1) == out_x) &&
+               (pts_idx_of_voxels.size(2) == out_y) &&
+               (pts_idx_of_voxels.size(3) == out_z) &&
+               (pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
+              "the dimensions of pts_idx_of_voxels should be (boxes_num, "
+              "out_x, out_y, out_z, max_pts_each_voxel), ",
+              "but got (", pts_idx_of_voxels.size(0), ",",
+              pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",",
+              pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ").");
+  TORCH_CHECK(((pooled_features.size(0) == boxes_num) &&
+               (pooled_features.size(1) == out_x) &&
+               (pooled_features.size(2) == out_y) &&
+               (pooled_features.size(3) == out_z) &&
+               (pooled_features.size(4) == channels)),
+              "the dimensions of pooled_features should be (boxes_num, out_x, "
+              "out_y, out_z, channels), ",
+              "but got (", pooled_features.size(0), ",",
+              pooled_features.size(1), ",", pooled_features.size(2), ",",
+              pooled_features.size(3), ",", pooled_features.size(4), ").");
+  // check other params : pool_method
+  TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
+              "pool_method should be 0(max) or 1(avg), ", "but got ",
+              pool_method, ".");
+  // check large tensor
+  const size_t max_input_size = 2147483648;
+  TORCH_CHECK(rois.numel() < max_input_size,
+              "rois element num should be less than 2^31, got ", rois.numel(),
+              ".");
+  TORCH_CHECK(pts.numel() < max_input_size,
+              "pts element num should be less than 2^31, got ", pts.numel(),
+              ".");
+  TORCH_CHECK(pts_feature.numel() < max_input_size,
+              "pts_feature element num should be less than 2^31, got ",
+              pts_feature.numel(), ".");
+  TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
+              "pts_idx_of_voxels element num should be less than 2^31, got ",
+              pts_idx_of_voxels.numel(), ".");
+  TORCH_CHECK(pooled_features.numel() < max_input_size,
+              "pooled_features element num should be less than 2^31, got ",
+              pooled_features.numel(), ".");
+  // check zero element
+  TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
+              rois.numel());
+  TORCH_CHECK(pts.numel() != 0, "pts.numel() should not be zero, got ",
+              pts.numel());
+  TORCH_CHECK(pts_feature.numel() != 0,
+              "pts_feature.numel() should not be zero, got ",
+              pts_feature.numel());
+  TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
+              "pts_idx_of_voxels.numel() should not be zero, got ",
+              pts_idx_of_voxels.numel());
+  TORCH_CHECK(pooled_features.numel() != 0,
+              "pooled_features.numel() should not be zero, got ",
+              pooled_features.numel());
+  if (pool_method == 0) {
+    // check datatype
+    TORCH_CHECK((argmax.scalar_type() == at::kInt),
+                "argmax type should be Int, got ", argmax.scalar_type(), ".");
+    // check dim
+    TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
+                argmax.dim(), "D.");
+    // check shape
+    TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
+                 (argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
+                 (argmax.size(4) == channels)),
+                "the dimensions of argmax should be (boxes_num, out_x, out_y, "
+                "out_z, channels), ",
+                "but got (", argmax.size(0), ",", argmax.size(1), ",",
+                argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
+    // check large tensor
+    TORCH_CHECK(argmax.numel() < max_input_size,
+                "argmax element num should be less than 2^31, got ",
+                argmax.numel(), ".");
+    // check zero element
+    TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ",
+                argmax.numel());
+    // when pool_method is 0, which is max pool, init argmax data value to -1
+    argmax.fill_(static_cast<int>(-1));
+  }
+  // calculate task one dimension
+  cnrtDim3_t k1_dim;
+  cnrtFunctionType_t k1_type;
+  kernelPtsIdxOfVoxelsPolicyFunc(boxes_num, &k1_dim, &k1_type);
+  cnrtDim3_t k2_dim;
+  cnrtFunctionType_t k2_type;
+  kernelRoiawarePool3dForwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k2_dim,
+                                        &k2_type);
+  // get compute queue
+  auto queue = torch_mlu::getCurQueue();
+  // get ptr of tensors
+  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
+  auto rois_ptr = rois_impl->cnnlMalloc();
+  // transpose points [pts_num, 3] -> [3, pts_num]
+  auto pts_ = pts.permute({1, 0}).contiguous();
+  auto pts_impl = torch_mlu::getMluTensorImpl(pts_);
+  auto pts_ptr = pts_impl->cnnlMalloc();
+  // transpose points_features [pts_num, channels] -> [channels, pts_num]
+  auto pts_feature_ = pts_feature.permute({1, 0}).contiguous();
+  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
+  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
+  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
+  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
+  auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
+  auto
pooled_features_ptr = pooled_features_impl->cnnlMalloc(); + auto argmax_impl = torch_mlu::getMluTensorImpl(argmax); + auto argmax_ptr = argmax_impl->cnnlMalloc(); + // get compute dtype of input + cnrtDataType_t data_type = torch_mlu::toCnrtDtype(rois.dtype()); + // launch kernel PtsIdxOfVoxels + CNLOG(INFO) << "Launch Kernel MLUKernel PtsIdxOfVoxels<<<" << k1_dim.x << ", " + << k1_dim.y << ", " << k1_dim.z << ">>>"; + KernelPtsIdxOfVoxels(k1_dim, k1_type, queue, data_type, pool_method, + boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, + out_z, rois_ptr, pts_ptr, (int *)pts_idx_of_voxels_ptr); + // launch kernel RoiawarePool3dForward + CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dForward<<<" << k2_dim.x + << ", " << k2_dim.y << ", " << k2_dim.z << ">>>"; + KernelRoiawarePool3dForward( + k2_dim, k2_type, queue, data_type, pool_method, boxes_num, pts_num, + channels, max_pts_each_voxel, out_x, out_y, out_z, pts_feature_ptr, + (int *)pts_idx_of_voxels_ptr, pooled_features_ptr, (int *)argmax_ptr); +} + +void roiaware_pool3d_forward_mlu(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, const Tensor pts, + const Tensor pts_feature, Tensor argmax, + Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + RoiawarePool3dForwardMLUKernelLauncher( + pool_method, boxes_num, pts_num, channels, max_pts_each_voxel, out_x, + out_y, out_z, rois, pts, pts_feature, pts_idx_of_voxels, pooled_features, + argmax); +} + +void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, + const Tensor pts, const Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MLU, + roiaware_pool3d_forward_mlu); + +void KernelRoiawarePool3dBackward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const int pool_method, const int boxes_num, + const int out_x, const int out_y, const int out_z, const int channels, + const int max_pts_each_voxel, const int *pts_idx_of_voxels, + const int *argmax, const void *grad_out, void *grad_in); + +static void kernelRoiawarePool3dBackwardPolicyFunc( + const int boxes_num, const int out_x, const int out_y, const int out_z, + cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { + unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + *k_type = CNRT_FUNC_TYPE_UNION1; + k_dim->x = core_num; + const int voxels_num = boxes_num * out_x * out_y * out_z; + unsigned int use_cluster = (voxels_num + core_num - 1) / core_num; + k_dim->y = use_cluster > cluster_num ? 
cluster_num : use_cluster;
+  k_dim->z = 1;
+}
+
+void RoiawarePool3dBackwardMLUKernelLauncher(
+    int pool_method, int boxes_num, int out_x, int out_y, int out_z,
+    int channels, int max_pts_each_voxel, const Tensor pts_idx_of_voxels,
+    const Tensor argmax, const Tensor grad_out, Tensor grad_in) {
+  // check datatype
+  TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
+              "pts_idx_of_voxels type should be Int, got ",
+              pts_idx_of_voxels.scalar_type(), ".");
+  TORCH_CHECK((argmax.scalar_type() == at::kInt),
+              "argmax type should be Int, got ", argmax.scalar_type(), ".");
+  TORCH_CHECK((grad_out.scalar_type() == at::kFloat ||
+               grad_out.scalar_type() == at::kHalf),
+              "grad_out type should be Float or Half, got ",
+              grad_out.scalar_type(), ".");
+  TORCH_CHECK((grad_out.scalar_type() == grad_in.scalar_type()),
+              "data types of grad_out and grad_in should be the same, ",
+              "but now grad_out type is ", grad_out.scalar_type(),
+              ", grad_in type is ", grad_in.scalar_type(), ".");
+  // check dim
+  TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
+              "pts_idx_of_voxels should be a 5D tensor, got ",
+              pts_idx_of_voxels.dim(), "D.");
+  TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
+              argmax.dim(), "D.");
+  TORCH_CHECK(grad_out.dim() == 5, "grad_out should be a 5D tensor, got ",
+              grad_out.dim(), "D.");
+  TORCH_CHECK(grad_in.dim() == 2, "grad_in should be a 2D tensor, got ",
+              grad_in.dim(), "D.");
+  // check shape
+  TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
+               (pts_idx_of_voxels.size(1) == out_x) &&
+               (pts_idx_of_voxels.size(2) == out_y) &&
+               (pts_idx_of_voxels.size(3) == out_z) &&
+               (pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
+              "the dimensions of pts_idx_of_voxels should be (boxes_num, "
+              "out_x, out_y, out_z, max_pts_each_voxel), ",
+              "but got (", pts_idx_of_voxels.size(0), ",",
+              pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",",
+              pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ").");
+  TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
+               (argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
+               (argmax.size(4) == channels)),
+              "the dimensions of argmax should be (boxes_num, out_x, out_y, "
+              "out_z, channels), ",
+              "but got (", argmax.size(0), ",", argmax.size(1), ",",
+              argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
+  TORCH_CHECK(((grad_out.size(0) == boxes_num) && (grad_out.size(1) == out_x) &&
+               (grad_out.size(2) == out_y) && (grad_out.size(3) == out_z) &&
+               (grad_out.size(4) == channels)),
+              "the dimensions of grad_out should be (boxes_num, out_x, "
+              "out_y, out_z, channels), ",
+              "but got (", grad_out.size(0), ",", grad_out.size(1), ",",
+              grad_out.size(2), ",", grad_out.size(3), ",", grad_out.size(4),
+              ").");
+  TORCH_CHECK((grad_in.size(1) == channels),
+              "dim 1 of grad_in should be channels, ", "but got ",
+              grad_in.size(1), ".");
+  // check other params : pool_method
+  TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
+              "pool_method should be 0(max) or 1(avg), ", "but got ",
+              pool_method, ".");
+  // check large tensor
+  const size_t max_input_size = 2147483648;
+  TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
+              "pts_idx_of_voxels element num should be less than 2^31, got ",
+              pts_idx_of_voxels.numel(), ".");
+  TORCH_CHECK(argmax.numel() < max_input_size,
+              "argmax element num should be less than 2^31, got ",
+              argmax.numel(), ".");
+  TORCH_CHECK(grad_out.numel() < max_input_size,
+              "grad_out element num should be less than 2^31, got ",
+              grad_out.numel(), ".");
+  TORCH_CHECK(grad_in.numel() < max_input_size,
+              "grad_in element num should be less than 2^31, got ",
+              grad_in.numel(), ".");
+  // check zero element
+  TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
+              "pts_idx_of_voxels.numel() should not be zero, got ",
+              pts_idx_of_voxels.numel());
+  TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ",
+              argmax.numel());
+  TORCH_CHECK(grad_out.numel() != 0,
+              "grad_out.numel() should not be zero, got ", grad_out.numel());
+  TORCH_CHECK(grad_in.numel() != 0, "grad_in.numel() should not be zero, got ",
+              grad_in.numel());
+  // calculate task one dimension
+  cnrtDim3_t k_dim;
+  cnrtFunctionType_t k_type;
+  kernelRoiawarePool3dBackwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k_dim,
+                                         &k_type);
+  // get compute queue
+  auto queue = torch_mlu::getCurQueue();
+  // get ptr of tensors
+  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
+  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
+  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
+  auto argmax_ptr = argmax_impl->cnnlMalloc();
+  auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out);
+  auto grad_out_ptr = grad_out_impl->cnnlMalloc();
+  auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in);
+  auto grad_in_ptr = grad_in_impl->cnnlMalloc();
+  // get compute dtype of input
+  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(grad_out.dtype());
+  // launch kernel RoiawarePool3dBackward
+  CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dBackward<<<" << k_dim.x
+              << ", " << k_dim.y << ", " << k_dim.z << ">>>";
+  KernelRoiawarePool3dBackward(k_dim, k_type, queue, data_type, pool_method,
+                               boxes_num, out_x, out_y, out_z, channels,
+                               max_pts_each_voxel, (int *)pts_idx_of_voxels_ptr,
+                               (int *)argmax_ptr, grad_out_ptr, grad_in_ptr);
+}
+
+void roiaware_pool3d_backward_mlu(int boxes_num, int out_x, int out_y,
+                                  int out_z, int channels,
+                                  int max_pts_each_voxel,
+                                  const Tensor pts_idx_of_voxels,
+                                  const Tensor argmax, const Tensor grad_out,
+                                  Tensor grad_in, int pool_method) {
+  RoiawarePool3dBackwardMLUKernelLauncher(
+      pool_method, boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
+      pts_idx_of_voxels, argmax, grad_out, grad_in);
+}
+
+void roiaware_pool3d_backward_impl(int boxes_num, int out_x, int out_y,
+                                   int out_z, int channels,
+                                   int max_pts_each_voxel,
+                                   const Tensor pts_idx_of_voxels,
+                                   const Tensor argmax, const Tensor grad_out,
+                                   Tensor grad_in, int pool_method);
+
+REGISTER_DEVICE_IMPL(roiaware_pool3d_backward_impl, MLU,
+                     roiaware_pool3d_backward_mlu);
diff --git a/tests/test_ops/test_roiaware_pool3d.py b/tests/test_ops/test_roiaware_pool3d.py
index 7975448f8e..5391e924db 100644
--- a/tests/test_ops/test_roiaware_pool3d.py
+++ b/tests/test_ops/test_roiaware_pool3d.py
@@ -5,11 +5,27 @@
 from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
                       points_in_boxes_part)
-
-
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_RoIAwarePool3d():
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
+
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+@pytest.mark.parametrize('dtype', [
+    torch.float, torch.half,
+    pytest.param(
+        torch.double,
+        marks=pytest.mark.skipif(
+            IS_MLU_AVAILABLE, reason='MLU does not support double'))
+])
+def test_RoIAwarePool3d(device, dtype):
     roiaware_pool3d_max = RoIAwarePool3d(
         out_size=4, max_pts_per_voxel=128, mode='max')
     roiaware_pool3d_avg = RoIAwarePool3d(
@@ -17,27 +33,27 @@ def test_RoIAwarePool3d():
     rois = torch.tensor(
         [[1.0, 2.0, 3.0, 5.0, 4.0, 6.0, -0.3 - np.pi / 2],
          [-10.0, 23.0, 16.0, 20.0, 10.0, 20.0, -0.5 - np.pi / 2]],
-        dtype=torch.float32).cuda(
-        )  # boxes (m, 7) with bottom center in lidar coordinate
+        dtype=dtype).to(device)
+    # boxes (m, 7) with bottom center in lidar coordinate
     pts = torch.tensor(
         [[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
          [0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
          [4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9],
          [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]],
-        dtype=torch.float32).cuda()  # points (n, 3) in lidar coordinate
+        dtype=dtype).to(device)  # points (n, 3) in lidar coordinate
     pts_feature = pts.clone()
     pooled_features_max = roiaware_pool3d_max(
         rois=rois, pts=pts, pts_feature=pts_feature)
     assert pooled_features_max.shape == torch.Size([2, 4, 4, 4, 3])
     assert torch.allclose(pooled_features_max.sum(),
-                          torch.tensor(51.100).cuda(), 1e-3)
+                          torch.tensor(51.100, dtype=dtype).to(device), 1e-3)
     pooled_features_avg = roiaware_pool3d_avg(
         rois=rois, pts=pts, pts_feature=pts_feature)
     assert pooled_features_avg.shape == torch.Size([2, 4, 4, 4, 3])
     assert torch.allclose(pooled_features_avg.sum(),
-                          torch.tensor(49.750).cuda(), 1e-3)
+                          torch.tensor(49.750, dtype=dtype).to(device), 1e-3)
 
 
 @pytest.mark.skipif(