diff --git a/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp
index 000f8882b1..f57b0e88fb 100644
--- a/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp
@@ -34,9 +34,9 @@ void ball_query_forward_mlu(int b, int n, int m, float min_radius,
   auto idx_ptr = idx_impl->cnnlMalloc();
 
   auto handle = mluOpGetCurrentHandle();
-  mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr, xyz_desc.desc(),
-                 xyz_ptr, min_radius, max_radius, nsample, idx_desc.desc(),
-                 idx_ptr);
+  TORCH_MLUOP_CHECK(mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr, xyz_desc.desc(),
+                                   xyz_ptr, min_radius, max_radius, nsample, idx_desc.desc(),
+                                   idx_ptr));
 }
 
 void ball_query_forward_impl(int b, int n, int m, float min_radius,
diff --git a/mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp b/mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp
index 6a903973d0..15fd47ace8 100644
--- a/mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp
@@ -38,9 +38,9 @@ void BoxIouRotatedMLUKernelLauncher(const Tensor boxes1, const Tensor boxes2,
   auto ious_ptr = ious_impl->cnnlMalloc();
 
   CNLOG(INFO) << "Call mluOpBoxIouRotated().";
-  mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr,
-                     boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(),
-                     ious_ptr);
+  TORCH_MLUOP_CHECK(mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr,
+                                       boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(),
+                                       ious_ptr));
 }
 
 void box_iou_rotated_mlu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
diff --git a/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp
index 5a7d6c7e39..a78c7cfda3 100644
--- a/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp
@@ -71,15 +71,15 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
   // set op descriptor
   auto handle = mluOpGetCurrentHandle();
   mluOpCarafeDescriptor_t carafe_desc;
-  mluOpCreateCarafeDescriptor(&carafe_desc);
-  mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
-                           scale_factor);
+  TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&carafe_desc));
+  TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
+                                             scale_factor));
   // launch kernel
-  mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
-                     mask_desc.desc(), mask_ptr, output_desc.desc(),
-                     output_ptr);
+  TORCH_MLUOP_CHECK(mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
+                                       mask_desc.desc(), mask_ptr, output_desc.desc(),
+                                       output_ptr));
   // destroy op descriptor
-  mluOpDestroyCarafeDescriptor(carafe_desc);
+  TORCH_MLUOP_CHECK(mluOpDestroyCarafeDescriptor(carafe_desc));
 
   // copy output from NHWC back into NCHW
   rinput.copy_(rinput_);
@@ -159,16 +159,16 @@
   // set op descriptor
   auto handle = mluOpGetCurrentHandle();
   mluOpCarafeDescriptor_t carafe_desc;
-  mluOpCreateCarafeDescriptor(&carafe_desc);
-  mluOpSetCarafeDescriptor(carafe_desc, grad_output.dim(), kernel_size,
-                           group_size, scale_factor);
+  TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&carafe_desc));
+  TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(carafe_desc, grad_output.dim(), kernel_size,
+                                             group_size, scale_factor));
   // launch kernel
-  mluOpCarafeBackward(handle, carafe_desc, input_desc.desc(), input_ptr,
-                      mask_desc.desc(), mask_ptr, grad_output_desc.desc(),
-                      grad_output_ptr, grad_input_desc.desc(), grad_input_ptr,
-                      grad_mask_desc.desc(), grad_mask_ptr);
+  TORCH_MLUOP_CHECK(mluOpCarafeBackward(handle, carafe_desc, input_desc.desc(), input_ptr,
+                                        mask_desc.desc(), mask_ptr, grad_output_desc.desc(),
+                                        grad_output_ptr, grad_input_desc.desc(), grad_input_ptr,
+                                        grad_mask_desc.desc(), grad_mask_ptr));
   // destroy op descriptor
-  mluOpDestroyCarafeDescriptor(carafe_desc);
+  TORCH_MLUOP_CHECK(mluOpDestroyCarafeDescriptor(carafe_desc));
 
   // copy output from NHWC back into NCHW
   grad_input.copy_(rgrad_input_);
diff --git a/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp
index 90a625c4a2..dd6925e168 100644
--- a/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp
@@ -50,10 +50,10 @@ void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois,
   // get compute handle
   auto handle = mluOpGetCurrentHandle();
 
-  mluOpDeformRoiPoolForward(
-      handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
-      offset_real_desc, offset_ptr, pooled_height, pooled_width, spatial_scale,
-      sampling_ratio, gamma, output_desc.desc(), output_ptr);
+  TORCH_MLUOP_CHECK(mluOpDeformRoiPoolForward(
+      handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
+      offset_real_desc, offset_ptr, pooled_height, pooled_width, spatial_scale,
+      sampling_ratio, gamma, output_desc.desc(), output_ptr));
 
   output.copy_(output_contiguous);
 }
@@ -113,12 +113,12 @@ void DeformRoIPoolBackwardMLUKernelLauncher(
   // get compute handle
   auto handle = mluOpGetCurrentHandle();
 
-  mluOpDeformRoiPoolBackward(
-      handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(),
-      input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr,
-      pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma,
-      grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc,
-      grad_offset_ptr);
+  TORCH_MLUOP_CHECK(mluOpDeformRoiPoolBackward(
+      handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(),
+      input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr,
+      pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma,
+      grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc,
+      grad_offset_ptr));
 
   grad_input.copy_(grad_input_);
 }
diff --git a/mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp
index d50bddca55..10267b8bdd 100644
--- a/mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp
@@ -42,9 +42,9 @@ Tensor diff_iou_rotated_sort_vertices_forward_mlu(Tensor vertices, Tensor mask,
   auto handle = mluOpGetCurrentHandle();
 
   // launch kernel
-  mluOpDiffIouRotatedSortVerticesForward(
-      handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
-      num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr);
+  TORCH_MLUOP_CHECK(mluOpDiffIouRotatedSortVerticesForward(
+      handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
+      num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr));
   return idx;
 }
diff --git a/mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp
index 993aa5e410..43abebc5bb 100644
--- a/mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp
@@ -30,7 +30,7 @@ void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
   // workspace
   size_t workspace_size = 0;
   auto handle = mluOpGetCurrentHandle();
-  mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), NULL, &workspace_size);
+  TORCH_MLUOP_CHECK(mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), NULL, &workspace_size));
   auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
 
   // get compute queue
@@ -56,16 +56,16 @@ void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
   const int max_output_size = input_box_num;
   const float offset = 0.0;
 
-  mluOpCreateNmsDescriptor(&nms_desc);
-  mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
-                        iou_threshold, soft_nms_sigma, max_output_size,
-                        confidence_threshold, offset, input_layout,
-                        pad_to_max_output_size);
+  TORCH_MLUOP_CHECK(mluOpCreateNmsDescriptor(&nms_desc));
+  TORCH_MLUOP_CHECK(mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
+                                          iou_threshold, soft_nms_sigma, max_output_size,
+                                          confidence_threshold, offset, input_layout,
+                                          pad_to_max_output_size));
 
-  mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, NULL, NULL,
-           workspace_ptr, workspace_size, output_desc.desc(), output_ptr,
-           output_size_ptr);
-  mluOpDestroyNmsDescriptor(nms_desc);
+  TORCH_MLUOP_CHECK(mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, NULL, NULL,
+                             workspace_ptr, workspace_size, output_desc.desc(), output_ptr,
+                             output_size_ptr));
+  TORCH_MLUOP_CHECK(mluOpDestroyNmsDescriptor(nms_desc));
 }
 
 void iou3d_nms3d_forward_mlu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp
index 2799d3aa1b..35395c5ba9 100644
--- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp
@@ -123,7 +123,7 @@ void MluOpTensorDescriptor::set_desc(const at::Tensor& t,
                                      mluOpDataType_t dtype,
                                      std::vector<int>& dims) {
   int dimNb = dims.size();
-  mluOpSetTensorDescriptor(desc_, layout, dtype, dimNb, dims.data());
+  TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(desc_, layout, dtype, dimNb, dims.data()));
 }
 
 // Handles
diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
index 8743f2b333..806cc2acc8 100644
--- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
+++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
@@ -34,6 +34,17 @@
   auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous);          \
   auto NAME##_ptr = NAME##_impl->cnnlMalloc();
 
+#ifndef TORCH_MLUOP_CHECK
+#define TORCH_MLUOP_CHECK(EXPR)                                          \
+  do {                                                                   \
+    mluOpStatus_t status = EXPR;                                         \
+    if (status != MLUOP_STATUS_SUCCESS) {                                \
+      CNLOG(ERROR) << "";                                                \
+      TORCH_CHECK(false, "MLUOPS error: ", mluOpGetErrorString(status)); \
+    }                                                                    \
+  } while (0);
+#endif
+
 enum class reduce_t { SUM = 0, MEAN = 1, MAX = 2 };
 
 inline std::string to_string(reduce_t reduce_type) {
@@ -54,8 +65,8 @@ mluOpReduceMode_t getMluOpReduceMode(const reduce_t reduce_type);
 
 class MluOpTensorDescriptor {
  public:
-  MluOpTensorDescriptor() { mluOpCreateTensorDescriptor(&desc_); };
-  ~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); }
+  MluOpTensorDescriptor() { TORCH_MLUOP_CHECK(mluOpCreateTensorDescriptor(&desc_)); };
+  ~MluOpTensorDescriptor() { TORCH_MLUOP_CHECK(mluOpDestroyTensorDescriptor(desc_)); }
 
   void set(at::Tensor);
   void set_with_layout(at::Tensor, mluOpTensorLayout_t layout);
@@ -71,14 +82,14 @@ mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index = -1);
 
 class MluOpHandle {
  public:
-  MluOpHandle() : handle(nullptr) { mluOpCreate(&handle); }
+  MluOpHandle() : handle(nullptr) { TORCH_MLUOP_CHECK(mluOpCreate(&handle)); }
   ~MluOpHandle() {
     if (handle) {
-      mluOpDestroy(handle);
+      TORCH_MLUOP_CHECK(mluOpDestroy(handle));
       handle = nullptr;
     }
   }
-  void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); }
+  void setQueue(cnrtQueue_t queue) { TORCH_MLUOP_CHECK(mluOpSetQueue(handle, queue)); }
   mluOpHandle_t handle;
 };
diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
index 2643bc537e..2b46abfdd3 100644
--- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
@@ -35,12 +35,12 @@ Tensor MsDeformAttnForwardLauncher(const Tensor& value,
   INITIAL_MLU_PARAM_WITH_TENSOR(sampling_loc);
   INITIAL_MLU_PARAM_WITH_TENSOR(attn_weight);
 
-  mluOpMsDeformAttnForward(
-      handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(),
-      spatial_shapes_int_ptr, level_start_index_int_desc.desc(),
-      level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr,
-      attn_weight_desc.desc(), attn_weight_ptr, im2col_step, output_desc.desc(),
-      output_ptr);
+  TORCH_MLUOP_CHECK(mluOpMsDeformAttnForward(
+      handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(),
+      spatial_shapes_int_ptr, level_start_index_int_desc.desc(),
+      level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr,
+      attn_weight_desc.desc(), attn_weight_ptr, im2col_step, output_desc.desc(),
+      output_ptr));
 
   output = output.view({batch_size, num_queries, num_heads * channels});
   return output;
diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp
index eff6793f2d..32353fa07a 100644
--- a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp
@@ -34,8 +34,8 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
   // workspace
   size_t workspace_size = 0;
   auto handle = mluOpGetCurrentHandle();
-  mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), scores_desc.desc(),
-                           &workspace_size);
+  TORCH_MLUOP_CHECK(mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), scores_desc.desc(),
+                                             &workspace_size));
   auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
 
   // get compute queue
@@ -62,16 +62,16 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
   const bool pad_to_max_output_size = false;
   const int max_output_size = max_output_boxes;
 
-  mluOpCreateNmsDescriptor(&nms_desc);
-  mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
-                        iou_threshold, soft_nms_sigma, max_output_size,
-                        confidence_threshold, (float)offset, input_layout,
-                        pad_to_max_output_size);
+  TORCH_MLUOP_CHECK(mluOpCreateNmsDescriptor(&nms_desc));
+  TORCH_MLUOP_CHECK(mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
+                                          iou_threshold, soft_nms_sigma, max_output_size,
+                                          confidence_threshold, (float)offset, input_layout,
+                                          pad_to_max_output_size));
 
-  mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, scores_desc.desc(),
-           scores_ptr, workspace_ptr, workspace_size, output_desc.desc(),
-           output_ptr, output_size_ptr);
-  mluOpDestroyNmsDescriptor(nms_desc);
+  TORCH_MLUOP_CHECK(mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, scores_desc.desc(),
+                             scores_ptr, workspace_ptr, workspace_size, output_desc.desc(),
+                             output_ptr, output_size_ptr));
+  TORCH_MLUOP_CHECK(mluOpDestroyNmsDescriptor(nms_desc));
   int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
   auto ret = output.to(boxes.options().dtype(at::kLong));
   return ret.slice(0, 0, output_num);
diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp
index 9b45a17805..22b4f5f0f8 100644
--- a/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp
@@ -30,7 +30,7 @@ Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) {
   // workspace
   size_t workspace_size = 0;
   auto handle = mluOpGetCurrentHandle();
-  mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size);
+  TORCH_MLUOP_CHECK(mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size));
   auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));
 
   auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
@@ -44,9 +44,9 @@ Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) {
   auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
   auto output_size_ptr = output_size_impl->cnnlMalloc();
 
-  mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr,
-                  scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size,
-                  output_desc.desc(), output_ptr, (int *)output_size_ptr);
+  TORCH_MLUOP_CHECK(mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr,
+                                    scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size,
+                                    output_desc.desc(), output_ptr, (int *)output_size_ptr));
   int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
   auto ret = output.to(boxes.options().dtype(at::kLong));
   return ret.slice(0, 0, output_num);
diff --git a/mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp
index 3cce6c90a7..6c36bea598 100644
--- a/mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp
@@ -35,8 +35,8 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
   auto y_impl = torch_mlu::getMluTensorImpl(y_tmp);
   auto y_ptr = y_impl->cnnlMalloc();
 
-  mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
-                      y_desc.desc(), y_ptr);
+  TORCH_MLUOP_CHECK(mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
+                                        y_desc.desc(), y_ptr));
 
   y.copy_(y_tmp);
 }
@@ -67,8 +67,8 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
   auto dy_impl = torch_mlu::getMluTensorImpl(dy_tensor);
   auto dy_ptr = dy_impl->cnnlMalloc();
 
-  mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
-                       dx_tmp_desc.desc(), dx_ptr);
+  TORCH_MLUOP_CHECK(mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
+                                         dx_tmp_desc.desc(), dx_ptr));
 
   dx.copy_(dx_tmp);
 }
diff --git a/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
index ff6e5b1500..9c4b9a36c8 100644
--- a/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
@@ -48,10 +48,10 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
   auto output_ptr = output_impl->cnnlMalloc();
 
   mluOpRoiAlignForwardDescriptor_t roialign_desc;
-  mluOpCreateRoiAlignForwardDescriptor(&roialign_desc);
-  mluOpSetRoiAlignForwardDescriptor_v2(roialign_desc, aligned_height,
-                                       aligned_width, sampling_ratio,
-                                       spatial_scale, pool_mode, aligned);
+  TORCH_MLUOP_CHECK(mluOpCreateRoiAlignForwardDescriptor(&roialign_desc));
+  TORCH_MLUOP_CHECK(mluOpSetRoiAlignForwardDescriptor_v2(roialign_desc, aligned_height,
+                                                         aligned_width, sampling_ratio,
+                                                         spatial_scale, pool_mode, aligned));
 
   auto handle = mluOpGetCurrentHandle();
   if (pool_mode == 0) {
@@ -65,18 +65,18 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
     auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
     argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
     argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
-    mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
-                            rois_desc.desc(), rois_ptr, output_desc.desc(),
-                            output_ptr, argmax_x_desc.desc(), argmax_x_ptr,
-                            argmax_y_desc.desc(), argmax_y_ptr);
+    TORCH_MLUOP_CHECK(mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
+                                              rois_desc.desc(), rois_ptr, output_desc.desc(),
+                                              output_ptr, argmax_x_desc.desc(), argmax_x_ptr,
+                                              argmax_y_desc.desc(), argmax_y_ptr));
     argmax_x.copy_(argmax_x_contiguous);
     argmax_y.copy_(argmax_y_contiguous);
   } else {
-    mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
-                            rois_desc.desc(), rois_ptr, output_desc.desc(),
-                            output_ptr, NULL, NULL, NULL, NULL);
+    TORCH_MLUOP_CHECK(mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
+                                              rois_desc.desc(), rois_ptr, output_desc.desc(),
+                                              output_ptr, NULL, NULL, NULL, NULL));
   }
-  mluOpDestroyRoiAlignForwardDescriptor(roialign_desc);
+  TORCH_MLUOP_CHECK(mluOpDestroyRoiAlignForwardDescriptor(roialign_desc));
 
   output.copy_(output_contiguous);
 }
@@ -136,16 +136,16 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
     auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
     argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
     argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
-    mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
-                             rois_desc.desc(), rois_ptr, argmax_y_desc.desc(),
-                             argmax_x_ptr, argmax_y_desc.desc(), argmax_y_ptr,
-                             spatial_scale, sampling_ratio, aligned, pool_mode,
-                             grad_input_desc.desc(), grad_input_ptr);
+    TORCH_MLUOP_CHECK(mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
+                                               rois_desc.desc(), rois_ptr, argmax_y_desc.desc(),
+                                               argmax_x_ptr, argmax_y_desc.desc(), argmax_y_ptr,
+                                               spatial_scale, sampling_ratio, aligned, pool_mode,
+                                               grad_input_desc.desc(), grad_input_ptr));
   } else {
-    mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
-                             rois_desc.desc(), rois_ptr, NULL, NULL, NULL, NULL,
-                             spatial_scale, sampling_ratio, aligned, pool_mode,
-                             grad_input_desc.desc(), grad_input_ptr);
+    TORCH_MLUOP_CHECK(mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
+                                               rois_desc.desc(), rois_ptr, NULL, NULL, NULL, NULL,
+                                               spatial_scale, sampling_ratio, aligned, pool_mode,
+                                               grad_input_desc.desc(), grad_input_ptr));
   }
   grad_input.copy_(grad_input_);
 }
diff --git a/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
index 7cf059cd51..fc05283c97 100644
--- a/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
@@ -40,10 +40,10 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
   // get compute handle
   auto handle = mluOpGetCurrentHandle();
 
-  mluOpRoiAlignRotatedForward(
-      handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
-      pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned,
-      clockwise, output_desc.desc(), output_ptr);
+  TORCH_MLUOP_CHECK(mluOpRoiAlignRotatedForward(
+      handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
+      pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned,
+      clockwise, output_desc.desc(), output_ptr));
 
   output.copy_(output_contiguous);
 }
@@ -76,10 +76,10 @@ void ROIAlignRotatedBackwardMLUKernelLauncher(
   // get compute handle
   auto handle = mluOpGetCurrentHandle();
 
-  mluOpRoiAlignRotatedBackward(
-      handle, top_grad_desc.desc(), top_grad_ptr, rois_desc.desc(), rois_ptr,
-      pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned,
-      clockwise, bottom_grad_desc.desc(), bottom_grad_ptr);
+  TORCH_MLUOP_CHECK(mluOpRoiAlignRotatedBackward(
+      handle, top_grad_desc.desc(), top_grad_ptr, rois_desc.desc(), rois_ptr,
+      pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned,
+      clockwise, bottom_grad_desc.desc(), bottom_grad_ptr));
 
   bottom_grad.copy_(bottom_grad_);
 }
diff --git a/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp
index a1c4da4ca3..058be6b27b 100644
--- a/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp
@@ -44,9 +44,9 @@ void RoiawarePool3dForwardMLUKernelLauncher(
   // allocate extra space for workspace
   size_t workspace_size = 0;
-  mluOpGetRoiawarePool3dForwardWorkspaceSize(
-      handle, rois_desc.desc(), pts_desc.desc(), pts_feature_desc.desc(),
-      &workspace_size);
+  TORCH_MLUOP_CHECK(mluOpGetRoiawarePool3dForwardWorkspaceSize(
+      handle, rois_desc.desc(), pts_desc.desc(), pts_feature_desc.desc(),
+      &workspace_size));
 
   auto workspace = at::empty(workspace_size, rois.options().dtype(at::kByte));
   auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
@@ -69,13 +69,13 @@
   auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
 
   CNLOG(INFO) << "Call mluOpRoiawarePool3dForward().";
-  mluOpRoiawarePool3dForward(
+  TORCH_MLUOP_CHECK(mluOpRoiawarePool3dForward(
       handle, pool_method, boxes_num, pts_num, channels, rois_desc.desc(),
       rois_ptr, pts_desc.desc(), pts_ptr, pts_feature_desc.desc(),
       pts_feature_ptr, workspace_ptr, workspace_size, max_pts_each_voxel,
       out_x, out_y, out_z, argmax_desc.desc(), argmax_ptr,
       pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
-      pooled_features_desc.desc(), pooled_features_ptr);
+      pooled_features_desc.desc(), pooled_features_ptr));
 }
 
 void roiaware_pool3d_forward_mlu(int boxes_num, int pts_num, int channels,
@@ -135,11 +135,11 @@
   auto grad_in_ptr = grad_in_impl->cnnlMalloc();
 
   CNLOG(INFO) << "Call mluOpRoiawarePool3dBackward().";
-  mluOpRoiawarePool3dBackward(
+  TORCH_MLUOP_CHECK(mluOpRoiawarePool3dBackward(
       handle, pool_method, boxes_num, out_x, out_y, out_z, channels,
       max_pts_each_voxel, pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
       argmax_desc.desc(), argmax_ptr, grad_out_desc.desc(), grad_out_ptr,
-      grad_in_desc.desc(), grad_in_ptr);
+      grad_in_desc.desc(), grad_in_ptr));
 }
 
 void roiaware_pool3d_backward_mlu(int boxes_num, int out_x, int out_y,
diff --git a/mmcv/ops/csrc/pytorch/mlu/rotated_feature_align_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/rotated_feature_align_mlu.cpp
index a827210d2b..9891123f6c 100644
--- a/mmcv/ops/csrc/pytorch/mlu/rotated_feature_align_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/rotated_feature_align_mlu.cpp
@@ -40,9 +40,9 @@ void RotatedFeatureAlignForwardMLUKernelLauncher(const Tensor features,
   // get compute handle
   auto handle = mluOpGetCurrentHandle();
 
-  mluOpRotatedFeatureAlignForward(
+  TORCH_MLUOP_CHECK(mluOpRotatedFeatureAlignForward(
       handle, features_desc.desc(), features_ptr, best_bboxes_desc.desc(),
-      best_bboxes_ptr, spatial_scale, points, output_desc.desc(), output_ptr);
+      best_bboxes_ptr, spatial_scale, points, output_desc.desc(), output_ptr));
 
   output.copy_(output_contiguous);
 }
@@ -76,10 +76,10 @@ void RotatedFeatureAlignBackwardMLUKernelLauncher(const Tensor top_grad,
   // get compute handle
   auto handle = mluOpGetCurrentHandle();
 
-  mluOpRotatedFeatureAlignBackward(handle, top_grad_desc.desc(), top_grad_ptr,
-                                   best_bboxes_desc.desc(), best_bboxes_ptr,
-                                   spatial_scale, points,
-                                   bottom_grad_desc.desc(), bottom_grad_ptr);
+  TORCH_MLUOP_CHECK(mluOpRotatedFeatureAlignBackward(handle, top_grad_desc.desc(), top_grad_ptr,
+                                                     best_bboxes_desc.desc(), best_bboxes_ptr,
+                                                     spatial_scale, points,
+                                                     bottom_grad_desc.desc(), bottom_grad_ptr));
 
   bottom_grad.copy_(bottom_grad_);
 }
diff --git a/mmcv/ops/csrc/pytorch/mlu/scatter_points_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/scatter_points_mlu.cpp
index cf9713ee22..fa50987220 100644
--- a/mmcv/ops/csrc/pytorch/mlu/scatter_points_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/scatter_points_mlu.cpp
@@ -49,20 +49,20 @@ std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(
   auto handle = mluOpGetCurrentHandle();
   size_t workspace_size;
-  mluOpGetDynamicPointToVoxelForwardWorkspaceSize(
-      handle, feats_desc.desc(), coors_desc.desc(), &workspace_size);
+  TORCH_MLUOP_CHECK(mluOpGetDynamicPointToVoxelForwardWorkspaceSize(
+      handle, feats_desc.desc(), coors_desc.desc(), &workspace_size));
   auto workspace_tensor =
       at::empty(workspace_size, feats.options().dtype(at::kByte));
   INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);
 
   // launch kernel
-  mluOpDynamicPointToVoxelForward(
+  TORCH_MLUOP_CHECK(mluOpDynamicPointToVoxelForward(
       handle, mlu_reduce_type, feats_desc.desc(), feats_ptr, coors_desc.desc(),
       coors_ptr, workspace_tensor_ptr, workspace_size,
       reduced_feats_desc.desc(), reduced_feats_ptr, out_coors_desc.desc(),
       out_coors_ptr, coors_map_desc.desc(), coors_map_ptr,
       reduce_count_desc.desc(), reduce_count_ptr, voxel_num_desc.desc(),
-      voxel_num_ptr);
+      voxel_num_ptr));
 
   int voxel_num_value = *static_cast<int *>(voxel_num.cpu().data_ptr());
   TORCH_CHECK(voxel_num_value <= feats.size(0),
@@ -124,22 +124,22 @@ void dynamic_point_to_voxel_backward_mlu(
   auto handle = mluOpGetCurrentHandle();
   size_t workspace_size;
-  mluOpGetDynamicPointToVoxelBackwardWorkspaceSize(
+  TORCH_MLUOP_CHECK(mluOpGetDynamicPointToVoxelBackwardWorkspaceSize(
       handle, mlu_reduce_type, grad_feats_desc.desc(), feats_desc.desc(),
       grad_reduced_feats_desc.desc(), coors_idx_desc.desc(),
-      reduce_count_desc.desc(), voxel_num_desc.desc(), &workspace_size);
+      reduce_count_desc.desc(), voxel_num_desc.desc(), &workspace_size));
   auto workspace_tensor =
       at::empty(workspace_size, feats.options().dtype(at::kByte));
   INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);
 
   // launch kernel
-  mluOpDynamicPointToVoxelBackward(
+  TORCH_MLUOP_CHECK(mluOpDynamicPointToVoxelBackward(
       handle, mlu_reduce_type, grad_reduced_feats_desc.desc(),
       grad_reduced_feats_ptr, feats_desc.desc(), feats_ptr,
       reduced_feats_desc.desc(), reduced_feats_ptr, coors_idx_desc.desc(),
       coors_idx_ptr, reduce_count_desc.desc(), reduce_count_ptr,
       voxel_num_desc.desc(), voxel_num_ptr, workspace_tensor_ptr,
-      workspace_size, grad_feats_desc.desc(), grad_feats_ptr);
+      workspace_size, grad_feats_desc.desc(), grad_feats_ptr));
 }
 
 std::vector<Tensor> dynamic_point_to_voxel_forward_impl(
diff --git a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
index 19cdb944fe..72ff94d24b 100644
--- a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
@@ -86,31 +86,31 @@ std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
     mluOpDataType_t dtype = MLUOP_DTYPE_INT32;
     std::vector<int> dims;
     dims = {numAct, coorDim + 1};
-    mluOpSetTensorDescriptor(indices_desc.desc(), layout, dtype, dims.size(),
-                             dims.data());
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(indices_desc.desc(), layout, dtype, dims.size(),
+                                               dims.data()));
     dims = {kernelVolume, 2, numAct};
-    mluOpSetTensorDescriptor(indicePairs_desc.desc(), layout, dtype,
-                             dims.size(), dims.data());
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(indicePairs_desc.desc(), layout, dtype,
+                                               dims.size(), dims.data()));
    dims = {kernelVolume};
-    mluOpSetTensorDescriptor(indiceNum_desc.desc(), layout, dtype, dims.size(),
-                             dims.data());
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(indiceNum_desc.desc(), layout, dtype, dims.size(),
+                                               dims.data()));
     dims = {out_size, coorDim + 1};
-    mluOpSetTensorDescriptor(out_indices_desc.desc(), layout, dtype,
-                             dims.size(), dims.data());
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(out_indices_desc.desc(), layout, dtype,
+                                               dims.size(), dims.data()));
   }
   mluOpSparseConvolutionDescriptor_t sparse_conv_desc;
-  mluOpCreateSparseConvolutionDescriptor(&sparse_conv_desc);
-  mluOpSetSparseConvolutionDescriptor(
+  TORCH_MLUOP_CHECK(mluOpCreateSparseConvolutionDescriptor(&sparse_conv_desc));
+  TORCH_MLUOP_CHECK(mluOpSetSparseConvolutionDescriptor(
       sparse_conv_desc, NDim + 2, batch, padding32.data(), stride32.data(),
       dilation32.data(), input_space.data(), filter_space.data(),
-      output_space.data(), sub_m, transpose, 0);
+      output_space.data(), sub_m, transpose, 0));
 
   auto handle = mluOpGetCurrentHandle();
   size_t workspace_size = 0;
-  mluOpGetIndicePairsWorkspaceSize(
+  TORCH_MLUOP_CHECK(mluOpGetIndicePairsWorkspaceSize(
       handle, sparse_conv_desc, indices_desc.desc(), indicePairs_desc.desc(),
-      out_indices_desc.desc(), indiceNum_desc.desc(), &workspace_size);
+      out_indices_desc.desc(), indiceNum_desc.desc(), &workspace_size));
 
   auto indice_workspace_size =
       at::empty(workspace_size, indices.options().dtype(at::kByte));
@@ -127,14 +127,14 @@ std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
   auto indiceNum_ptr = indiceNum_impl->cnnlMalloc();
   auto indice_workspace_ptr = indice_workspace_impl->cnnlMalloc();
 
-  mluOpGetIndicePairs(handle, sparse_conv_desc, indices_desc.desc(),
-                      indices_ptr, indice_workspace_ptr, workspace_size,
-                      indicePairs_desc.desc(), indicePairs_ptr,
-                      out_indices_desc.desc(), out_indices_ptr,
-                      indiceNum_desc.desc(), indiceNum_ptr);
+  TORCH_MLUOP_CHECK(mluOpGetIndicePairs(handle, sparse_conv_desc, indices_desc.desc(),
+                                        indices_ptr, indice_workspace_ptr, workspace_size,
+                                        indicePairs_desc.desc(), indicePairs_ptr,
+                                        out_indices_desc.desc(), out_indices_ptr,
+                                        indiceNum_desc.desc(), indiceNum_ptr));
   int num_act_out = 0;
-  mluOpGetSparseConvolutionNumActOut(sparse_conv_desc, &num_act_out);
-  mluOpDestroySparseConvolutionDescriptor(sparse_conv_desc);
+  TORCH_MLUOP_CHECK(mluOpGetSparseConvolutionNumActOut(sparse_conv_desc, &num_act_out));
+  TORCH_MLUOP_CHECK(mluOpDestroySparseConvolutionDescriptor(sparse_conv_desc));
   if (!sub_m) {
     return {out_indices.slice(0, 0, num_act_out), indicePairs, indiceNum};
   } else {
@@ -179,33 +179,33 @@ torch::Tensor IndiceConvForwardMLUKernelLauncher(
     int dims[8];
     // features_desc
-    mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
-    mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
-                             dim, dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims));
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                                               dim, dims));
     // filters_desc
-    mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
-    mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
-                             dim, dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims));
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                                               dim, dims));
     // indice_pairs_desc
-    mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
-                             dims);
-    mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
-                             dtype, dim, dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
+                                               dims));
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
+                                               dtype, dim, dims));
     // output_desc
-    mluOpGetTensorDescriptor(output_desc.desc(), &layout, &dtype, &dim, dims);
-    mluOpSetTensorDescriptor(output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim,
-                             dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(output_desc.desc(), &layout, &dtype, &dim, dims));
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim,
+                                               dims));
   }
 
   auto handle = mluOpGetCurrentHandle();
   size_t workspace_size = 0;
-  mluOpGetIndiceConvolutionForwardWorkspaceSize(
+  TORCH_MLUOP_CHECK(mluOpGetIndiceConvolutionForwardWorkspaceSize(
       handle, features_desc.desc(), filters_desc.desc(),
       indice_pairs_desc.desc(), output_desc.desc(), indice_num, numActOut,
-      _inverse, _subM, &workspace_size);
+      _inverse, _subM, &workspace_size));
 
   auto workspace =
       at::empty(workspace_size, features.options().dtype(at::kByte));
@@ -223,11 +223,11 @@ torch::Tensor IndiceConvForwardMLUKernelLauncher(
   // outputs
   auto output_impl = torch_mlu::getMluTensorImpl(output);
   auto output_ptr = output_impl->cnnlMalloc();
-  mluOpIndiceConvolutionForward(
+  TORCH_MLUOP_CHECK(mluOpIndiceConvolutionForward(
       handle, features_desc.desc(), features_ptr, filters_desc.desc(),
       filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
       numActOut, _inverse, _subM, workspace_ptr, workspace_size,
-      output_desc.desc(), output_ptr);
+      output_desc.desc(), output_ptr));
 
   return output;
 }
@@ -290,37 +290,37 @@ std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
     int dims[8];
     // features_desc
-    mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
-    mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
-                             dim, dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims));
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                                               dim, dims));
     // filters_desc
-    mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims));
     if (dim == 4) {
-      mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_HWCN, dtype,
-                               dim, dims);
+      TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_HWCN, dtype,
+                                                 dim, dims));
     } else {
-      mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
-                               dim, dims);
+      TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                                                 dim, dims));
     }
     // output_grad_desc
-    mluOpGetTensorDescriptor(output_grad_desc.desc(), &layout, &dtype, &dim,
-                             dims);
-    mluOpSetTensorDescriptor(output_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
-                             dim, dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(output_grad_desc.desc(), &layout, &dtype, &dim,
+                                               dims));
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(output_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                                               dim, dims));
     // indice_pairs_desc
-    mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
-                             dims);
-    mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
-                             dtype, dim, dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
+                                               dims));
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
+                                               dtype, dim, dims));
     // input_grad_desc
-    mluOpGetTensorDescriptor(input_grad_desc.desc(), &layout, &dtype, &dim,
-                             dims);
-    mluOpSetTensorDescriptor(input_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
-                             dim, dims);
+    TORCH_MLUOP_CHECK(mluOpGetTensorDescriptor(input_grad_desc.desc(), &layout, &dtype, &dim,
+                                               dims));
+    TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(input_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                                               dim, dims));
   }
 
   auto handle = mluOpGetCurrentHandle();
@@ -331,10 +331,10 @@ std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
       &data_workspace_size);
 
   size_t filters_workspace_size = 0;
-  mluOpGetIndiceConvolutionBackwardFilterWorkspaceSize(
+  TORCH_MLUOP_CHECK(mluOpGetIndiceConvolutionBackwardFilterWorkspaceSize(
       handle, features_desc.desc(), output_grad_desc.desc(),
       indice_pairs_desc.desc(), filters_grad_desc.desc(), indice_num, _inverse,
-      _subM, &filters_workspace_size);
+      _subM, &filters_workspace_size));
 
   auto indice_convbpdata_workspace =
       at::empty(data_workspace_size, features.options().dtype(at::kByte));
@@ -365,17 +365,17 @@ std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
   auto filters_grad_impl = torch_mlu::getMluTensorImpl(filters_grad);
   auto filters_grad_ptr = filters_grad_impl->cnnlMalloc();
 
-  mluOpIndiceConvolutionBackwardData(
+  TORCH_MLUOP_CHECK(mluOpIndiceConvolutionBackwardData(
       handle, output_grad_desc.desc(), output_grad_ptr, filters_desc.desc(),
       filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
      _inverse, _subM, indice_convbpdata_workspace_ptr, data_workspace_size,
-      input_grad_desc.desc(), input_grad_ptr);
+      input_grad_desc.desc(), input_grad_ptr));
 
-  mluOpIndiceConvolutionBackwardFilter(
+  TORCH_MLUOP_CHECK(mluOpIndiceConvolutionBackwardFilter(
       handle, features_desc.desc(), features_ptr, output_grad_desc.desc(),
      output_grad_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
      _inverse, _subM, indice_convbpfilter_workspace_ptr,
-      filters_workspace_size, filters_grad_desc.desc(), filters_grad_ptr);
+      filters_workspace_size, filters_grad_desc.desc(), filters_grad_ptr));
 
   std::vector<torch::Tensor> result;
   result.push_back(input_grad);
diff --git a/mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
index d464802691..0394bdc11a 100644
--- a/mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
@@ -30,8 +30,8 @@ void ThreeNNMLUKernelLauncher(int b, int n, int m, const Tensor unknown,
   auto handle = mluOpGetCurrentHandle();
   size_t workspace_size = 0;
-  mluOpGetThreeNNForwardWorkspaceSize(handle, known_desc.desc(),
-                                      &workspace_size);
+  TORCH_MLUOP_CHECK(mluOpGetThreeNNForwardWorkspaceSize(handle, known_desc.desc(),
+                                                        &workspace_size));
   auto known_workspace =
       at::empty(workspace_size, known.options().dtype(at::kByte));
@@ -46,10 +46,10 @@ void ThreeNNMLUKernelLauncher(int b, int n, int m, const Tensor unknown,
   auto idx_ptr = idx_impl->cnnlMalloc();
   auto workspace_ptr = workspace_impl->cnnlMalloc();
 
-  mluOpThreeNNForward(handle, unknown_desc.desc(), unknown_ptr,
-                      known_desc.desc(), known_ptr, workspace_ptr,
-                      workspace_size, dist2_desc.desc(), dist2_ptr,
-                      idx_desc.desc(), idx_ptr);
+  TORCH_MLUOP_CHECK(mluOpThreeNNForward(handle, unknown_desc.desc(), unknown_ptr,
+                                        known_desc.desc(), known_ptr, workspace_ptr,
+                                        workspace_size, dist2_desc.desc(), dist2_ptr,
+                                        idx_desc.desc(), idx_ptr));
 }
 
 void three_nn_forward_mlu(int b, int n, int m, const Tensor unknown,
diff --git a/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp
index 2ffd751ade..24226261b8 100644
--- a/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp
@@ -53,23 +53,23 @@ int HardVoxelizeForwardMLUKernelLauncher(
   size_t workspace_size;
   auto handle = mluOpGetCurrentHandle();
-  mluOpGetVoxelizationWorkspaceSize(
+  TORCH_MLUOP_CHECK(mluOpGetVoxelizationWorkspaceSize(
       handle, points_desc.desc(), voxel_size_tensor_desc.desc(),
       coors_range_tensor_desc.desc(), max_points, max_voxels, NDim, true,
       voxels_desc.desc(), coors_desc.desc(), num_points_per_voxel_desc.desc(),
-      voxel_num_tensor_desc.desc(), &workspace_size);
+      voxel_num_tensor_desc.desc(), &workspace_size));
   auto workspace_tensor =
       at::empty(workspace_size, points.options().dtype(at::kByte));
   INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);
 
-  mluOpVoxelization(handle, points_desc.desc(), points_ptr,
+  TORCH_MLUOP_CHECK(mluOpVoxelization(handle, points_desc.desc(), points_ptr,
                     voxel_size_tensor_desc.desc(), voxel_size_tensor_ptr,
                     coors_range_tensor_desc.desc(), coors_range_tensor_ptr,
                     max_points, max_voxels, NDim, true, workspace_tensor_ptr,
                     workspace_size, voxels_desc.desc(), voxels_ptr,
                     coors_desc.desc(), coors_ptr,
                     num_points_per_voxel_desc.desc(), num_points_per_voxel_ptr,
-                    voxel_num_tensor_desc.desc(), voxel_num_tensor_ptr);
+                    voxel_num_tensor_desc.desc(), voxel_num_tensor_ptr));
   auto voxel_num_cpu = voxel_num_tensor.to(at::kCPU);
   int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
   return voxel_num_int;