Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Add torch mluops check before calling mluOpsxxx interface #2871

Merged
merged 2 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ void ball_query_forward_mlu(int b, int n, int m, float min_radius,
auto idx_ptr = idx_impl->cnnlMalloc();

auto handle = mluOpGetCurrentHandle();
mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr, xyz_desc.desc(),
xyz_ptr, min_radius, max_radius, nsample, idx_desc.desc(),
idx_ptr);
TORCH_MLUOP_CHECK(mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr, xyz_desc.desc(),
xyz_ptr, min_radius, max_radius, nsample, idx_desc.desc(),
idx_ptr));
}

void ball_query_forward_impl(int b, int n, int m, float min_radius,
Expand Down
6 changes: 3 additions & 3 deletions mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ void BoxIouRotatedMLUKernelLauncher(const Tensor boxes1, const Tensor boxes2,
auto ious_ptr = ious_impl->cnnlMalloc();

CNLOG(INFO) << "Call mluOpBoxIouRotated().";
mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr,
boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(),
ious_ptr);
TORCH_MLUOP_CHECK(mluOpBoxIouRotated(handle, mode_flag, aligned, boxes1_desc.desc(), boxes1_ptr,
boxes2_desc.desc(), boxes2_ptr, ious_desc.desc(),
ious_ptr));
}

void box_iou_rotated_mlu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
Expand Down
30 changes: 15 additions & 15 deletions mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
// set op descriptor
auto handle = mluOpGetCurrentHandle();
mluOpCarafeDescriptor_t carafe_desc;
mluOpCreateCarafeDescriptor(&carafe_desc);
mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
scale_factor);
TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&carafe_desc));
TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
scale_factor));
// launch kernel
mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
mask_desc.desc(), mask_ptr, output_desc.desc(),
output_ptr);
TORCH_MLUOP_CHECK(mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
mask_desc.desc(), mask_ptr, output_desc.desc(),
output_ptr));
// destroy op descriptor
mluOpDestroyCarafeDescriptor(carafe_desc);
TORCH_MLUOP_CHECK(mluOpDestroyCarafeDescriptor(carafe_desc));

// copy output from NHWC back into NCHW
rinput.copy_(rinput_);
Expand Down Expand Up @@ -159,16 +159,16 @@ void CARAFEBackwardMLUKernelLauncher(
// set op descriptor
auto handle = mluOpGetCurrentHandle();
mluOpCarafeDescriptor_t carafe_desc;
mluOpCreateCarafeDescriptor(&carafe_desc);
mluOpSetCarafeDescriptor(carafe_desc, grad_output.dim(), kernel_size,
group_size, scale_factor);
TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&carafe_desc));
TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(carafe_desc, grad_output.dim(), kernel_size,
group_size, scale_factor));
// launch kernel
mluOpCarafeBackward(handle, carafe_desc, input_desc.desc(), input_ptr,
mask_desc.desc(), mask_ptr, grad_output_desc.desc(),
grad_output_ptr, grad_input_desc.desc(), grad_input_ptr,
grad_mask_desc.desc(), grad_mask_ptr);
TORCH_MLUOP_CHECK(mluOpCarafeBackward(handle, carafe_desc, input_desc.desc(), input_ptr,
mask_desc.desc(), mask_ptr, grad_output_desc.desc(),
grad_output_ptr, grad_input_desc.desc(), grad_input_ptr,
grad_mask_desc.desc(), grad_mask_ptr));
// destroy op descriptor
mluOpDestroyCarafeDescriptor(carafe_desc);
TORCH_MLUOP_CHECK(mluOpDestroyCarafeDescriptor(carafe_desc));

// copy output from NHWC back into NCHW
grad_input.copy_(rgrad_input_);
Expand Down
20 changes: 10 additions & 10 deletions mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois,

// get compute handle
auto handle = mluOpGetCurrentHandle();
mluOpDeformRoiPoolForward(
handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
offset_real_desc, offset_ptr, pooled_height, pooled_width, spatial_scale,
sampling_ratio, gamma, output_desc.desc(), output_ptr);
TORCH_MLUOP_CHECK(mluOpDeformRoiPoolForward(
handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr,
offset_real_desc, offset_ptr, pooled_height, pooled_width, spatial_scale,
sampling_ratio, gamma, output_desc.desc(), output_ptr));

output.copy_(output_contiguous);
}
Expand Down Expand Up @@ -113,12 +113,12 @@ void DeformRoIPoolBackwardMLUKernelLauncher(

// get compute handle
auto handle = mluOpGetCurrentHandle();
mluOpDeformRoiPoolBackward(
handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(),
input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr,
pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma,
grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc,
grad_offset_ptr);
TORCH_MLUOP_CHECK(mluOpDeformRoiPoolBackward(
handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(),
input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr,
pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma,
grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc,
grad_offset_ptr));
grad_input.copy_(grad_input_);
}

Expand Down
6 changes: 3 additions & 3 deletions mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ Tensor diff_iou_rotated_sort_vertices_forward_mlu(Tensor vertices, Tensor mask,
auto handle = mluOpGetCurrentHandle();

// launch kernel
mluOpDiffIouRotatedSortVerticesForward(
handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr);
TORCH_MLUOP_CHECK(mluOpDiffIouRotatedSortVerticesForward(
handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr));
return idx;
}

Expand Down
20 changes: 10 additions & 10 deletions mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
// workspace
size_t workspace_size = 0;
auto handle = mluOpGetCurrentHandle();
mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), NULL, &workspace_size);
TORCH_MLUOP_CHECK(mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), NULL, &workspace_size));
auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));

// get compute queue
Expand All @@ -56,16 +56,16 @@ void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
const int max_output_size = input_box_num;
const float offset = 0.0;

mluOpCreateNmsDescriptor(&nms_desc);
mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
iou_threshold, soft_nms_sigma, max_output_size,
confidence_threshold, offset, input_layout,
pad_to_max_output_size);
TORCH_MLUOP_CHECK(mluOpCreateNmsDescriptor(&nms_desc));
TORCH_MLUOP_CHECK(mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
iou_threshold, soft_nms_sigma, max_output_size,
confidence_threshold, offset, input_layout,
pad_to_max_output_size));

mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, NULL, NULL,
workspace_ptr, workspace_size, output_desc.desc(), output_ptr,
output_size_ptr);
mluOpDestroyNmsDescriptor(nms_desc);
TORCH_MLUOP_CHECK(mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, NULL, NULL,
workspace_ptr, workspace_size, output_desc.desc(), output_ptr,
output_size_ptr));
TORCH_MLUOP_CHECK(mluOpDestroyNmsDescriptor(nms_desc));
}

void iou3d_nms3d_forward_mlu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
Expand Down
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void MluOpTensorDescriptor::set_desc(const at::Tensor& t,
mluOpDataType_t dtype,
std::vector<int>& dims) {
int dimNb = dims.size();
mluOpSetTensorDescriptor(desc_, layout, dtype, dimNb, dims.data());
TORCH_MLUOP_CHECK(mluOpSetTensorDescriptor(desc_, layout, dtype, dimNb, dims.data()));
}

// Handles
Expand Down
10 changes: 5 additions & 5 deletions mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ mluOpReduceMode_t getMluOpReduceMode(const reduce_t reduce_type);

class MluOpTensorDescriptor {
public:
MluOpTensorDescriptor() { mluOpCreateTensorDescriptor(&desc_); };
~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); }
MluOpTensorDescriptor() { TORCH_MLUOP_CHECK(mluOpCreateTensorDescriptor(&desc_)); };
~MluOpTensorDescriptor() { TORCH_MLUOP_CHECK(mluOpDestroyTensorDescriptor(desc_)); }

void set(at::Tensor);
void set_with_layout(at::Tensor, mluOpTensorLayout_t layout);
Expand All @@ -71,14 +71,14 @@ mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index = -1);

class MluOpHandle {
public:
MluOpHandle() : handle(nullptr) { mluOpCreate(&handle); }
MluOpHandle() : handle(nullptr) { TORCH_MLUOP_CHECK(mluOpCreate(&handle)); }
~MluOpHandle() {
if (handle) {
mluOpDestroy(handle);
TORCH_MLUOP_CHECK(mluOpDestroy(handle));
handle = nullptr;
}
}
void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); }
void setQueue(cnrtQueue_t queue) { TORCH_MLUOP_CHECK(mluOpSetQueue(handle, queue)); }
mluOpHandle_t handle;
};

Expand Down
12 changes: 6 additions & 6 deletions mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ Tensor MsDeformAttnForwardLauncher(const Tensor& value,
INITIAL_MLU_PARAM_WITH_TENSOR(sampling_loc);
INITIAL_MLU_PARAM_WITH_TENSOR(attn_weight);

mluOpMsDeformAttnForward(
handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(),
spatial_shapes_int_ptr, level_start_index_int_desc.desc(),
level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr,
attn_weight_desc.desc(), attn_weight_ptr, im2col_step, output_desc.desc(),
output_ptr);
TORCH_MLUOP_CHECK(mluOpMsDeformAttnForward(
handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(),
spatial_shapes_int_ptr, level_start_index_int_desc.desc(),
level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr,
attn_weight_desc.desc(), attn_weight_ptr, im2col_step, output_desc.desc(),
output_ptr));

output = output.view({batch_size, num_queries, num_heads * channels});
return output;
Expand Down
22 changes: 11 additions & 11 deletions mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
// workspace
size_t workspace_size = 0;
auto handle = mluOpGetCurrentHandle();
mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), scores_desc.desc(),
&workspace_size);
TORCH_MLUOP_CHECK(mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), scores_desc.desc(),
&workspace_size));
auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));

// get compute queue
Expand All @@ -62,16 +62,16 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
const bool pad_to_max_output_size = false;
const int max_output_size = max_output_boxes;

mluOpCreateNmsDescriptor(&nms_desc);
mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
iou_threshold, soft_nms_sigma, max_output_size,
confidence_threshold, (float)offset, input_layout,
pad_to_max_output_size);
TORCH_MLUOP_CHECK(mluOpCreateNmsDescriptor(&nms_desc));
TORCH_MLUOP_CHECK(mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode,
iou_threshold, soft_nms_sigma, max_output_size,
confidence_threshold, (float)offset, input_layout,
pad_to_max_output_size));

mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, scores_desc.desc(),
scores_ptr, workspace_ptr, workspace_size, output_desc.desc(),
output_ptr, output_size_ptr);
mluOpDestroyNmsDescriptor(nms_desc);
TORCH_MLUOP_CHECK(mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, scores_desc.desc(),
scores_ptr, workspace_ptr, workspace_size, output_desc.desc(),
output_ptr, output_size_ptr));
TORCH_MLUOP_CHECK(mluOpDestroyNmsDescriptor(nms_desc));
int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
auto ret = output.to(boxes.options().dtype(at::kLong));
return ret.slice(0, 0, output_num);
Expand Down
8 changes: 4 additions & 4 deletions mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) {
// workspace
size_t workspace_size = 0;
auto handle = mluOpGetCurrentHandle();
mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size);
TORCH_MLUOP_CHECK(mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size));
auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));

auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
Expand All @@ -44,9 +44,9 @@ Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) {
auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
auto output_size_ptr = output_size_impl->cnnlMalloc();

mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr,
scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size,
output_desc.desc(), output_ptr, (int *)output_size_ptr);
TORCH_MLUOP_CHECK(mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr,
scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size,
output_desc.desc(), output_ptr, (int *)output_size_ptr));
int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
auto ret = output.to(boxes.options().dtype(at::kLong));
return ret.slice(0, 0, output_num);
Expand Down
8 changes: 4 additions & 4 deletions mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
auto y_impl = torch_mlu::getMluTensorImpl(y_tmp);
auto y_ptr = y_impl->cnnlMalloc();

mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
y_desc.desc(), y_ptr);
TORCH_MLUOP_CHECK(mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
y_desc.desc(), y_ptr));

y.copy_(y_tmp);
}
Expand Down Expand Up @@ -67,8 +67,8 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
auto dy_impl = torch_mlu::getMluTensorImpl(dy_tensor);
auto dy_ptr = dy_impl->cnnlMalloc();

mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
dx_tmp_desc.desc(), dx_ptr);
TORCH_MLUOP_CHECK(mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
dx_tmp_desc.desc(), dx_ptr));

dx.copy_(dx_tmp);
}
Expand Down
42 changes: 21 additions & 21 deletions mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
auto output_ptr = output_impl->cnnlMalloc();

mluOpRoiAlignForwardDescriptor_t roialign_desc;
mluOpCreateRoiAlignForwardDescriptor(&roialign_desc);
mluOpSetRoiAlignForwardDescriptor_v2(roialign_desc, aligned_height,
aligned_width, sampling_ratio,
spatial_scale, pool_mode, aligned);
TORCH_MLUOP_CHECK(mluOpCreateRoiAlignForwardDescriptor(&roialign_desc));
TORCH_MLUOP_CHECK(mluOpSetRoiAlignForwardDescriptor_v2(roialign_desc, aligned_height,
aligned_width, sampling_ratio,
spatial_scale, pool_mode, aligned));

auto handle = mluOpGetCurrentHandle();
if (pool_mode == 0) {
Expand All @@ -65,18 +65,18 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
rois_desc.desc(), rois_ptr, output_desc.desc(),
output_ptr, argmax_x_desc.desc(), argmax_x_ptr,
argmax_y_desc.desc(), argmax_y_ptr);
TORCH_MLUOP_CHECK(mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
rois_desc.desc(), rois_ptr, output_desc.desc(),
output_ptr, argmax_x_desc.desc(), argmax_x_ptr,
argmax_y_desc.desc(), argmax_y_ptr);
argmax_x.copy_(argmax_x_contiguous);
argmax_y.copy_(argmax_y_contiguous);
} else {
mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
rois_desc.desc(), rois_ptr, output_desc.desc(),
output_ptr, NULL, NULL, NULL, NULL);
TORCH_MLUOP_CHECK(mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
rois_desc.desc(), rois_ptr, output_desc.desc(),
output_ptr, NULL, NULL, NULL, NULL);
}
mluOpDestroyRoiAlignForwardDescriptor(roialign_desc);
TORCH_MLUOP_CHECK(mluOpDestroyRoiAlignForwardDescriptor(roialign_desc));
output.copy_(output_contiguous);
}

Expand Down Expand Up @@ -136,16 +136,16 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
rois_desc.desc(), rois_ptr, argmax_y_desc.desc(),
argmax_x_ptr, argmax_y_desc.desc(), argmax_y_ptr,
spatial_scale, sampling_ratio, aligned, pool_mode,
grad_input_desc.desc(), grad_input_ptr);
TORCH_MLUOP_CHECK(mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
rois_desc.desc(), rois_ptr, argmax_y_desc.desc(),
argmax_x_ptr, argmax_y_desc.desc(), argmax_y_ptr,
spatial_scale, sampling_ratio, aligned, pool_mode,
grad_input_desc.desc(), grad_input_ptr));
} else {
mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
rois_desc.desc(), rois_ptr, NULL, NULL, NULL, NULL,
spatial_scale, sampling_ratio, aligned, pool_mode,
grad_input_desc.desc(), grad_input_ptr);
TORCH_MLUOP_CHECK(mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
rois_desc.desc(), rois_ptr, NULL, NULL, NULL, NULL,
spatial_scale, sampling_ratio, aligned, pool_mode,
grad_input_desc.desc(), grad_input_ptr));
}
grad_input.copy_(grad_input_);
}
Expand Down
Loading