Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add the support of BoxIouRotated op for ascend device #2842

Merged
merged 1 commit into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc.
| BallQuery | ||| | |
| BBoxOverlaps | |||||
| BorderAlign | || | | |
| BoxIouRotated |||| | |
| BoxIouRotated |||| | |
| BoxIouQuadri ||| | | |
| CARAFE | ||| | |
| ChamferDistance | || | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| BallQuery | ||| | |
| BBoxOverlaps | |||||
| BorderAlign | || | | |
| BoxIouRotated |||| | |
| BoxIouRotated |||| | |
| BoxIouQuadri ||| | | |
| CARAFE | ||| | |
| ChamferDistance | || | | |
Expand Down
5 changes: 5 additions & 0 deletions mmcv/ops/box_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def box_iou_rotated(bboxes1: torch.Tensor,
flip_mat[-1] = -1
bboxes1 = bboxes1 * flip_mat
bboxes2 = bboxes2 * flip_mat
if bboxes1.device.type == 'npu':
scale_mat = bboxes1.new_ones(bboxes1.shape[-1])
scale_mat[-1] = 1.0 / 0.01745329252
bboxes1 = bboxes1 * scale_mat
bboxes2 = bboxes2 * scale_mat
bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous()
ext_module.box_iou_rotated(
Expand Down
47 changes: 47 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);

void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
at::Tensor boxes = at::ones_like(boxes1);
at::Tensor query_boxes = at::ones_like(boxes2);
boxes = boxes1.transpose(0, 1).unsqueeze(0);
query_boxes = boxes2.transpose(0, 1).unsqueeze(0);

bool is_trans = false;
string modeStr = "iou";
if (mode_flag == 1) {
modeStr = "iof";
}
bool is_cross = true;
if (aligned) {
is_cross = false;
}
float v_threshold = 0;
float e_threshold = 0;

OpCommand cmd;
cmd.Name("RotatedIou")
.Input(boxes)
.Input(query_boxes)
.Output(ious)
.Attr("trans", is_trans)
.Attr("mode", modeStr)
.Attr("is_cross", is_cross)
.Attr("v_threshold", v_threshold)
.Attr("e_threshold", e_threshold)
.Run();

if (is_cross) {
ious = ious.view({boxes1.size(0), boxes2.size(0)});
} else {
ious = ious.view({boxes1.size(0), 1});
}
}

REGISTER_NPU_IMPL(box_iou_rotated_impl, box_iou_rotated_npu);
14 changes: 11 additions & 3 deletions tests/test_ops/test_box_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from mmcv.ops import box_iou_rotated
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


class TestBoxIoURotated:
Expand Down Expand Up @@ -54,7 +54,11 @@ def test_box_iou_rotated_cpu(self):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_box_iou_rotated(self, device):
np_boxes1 = np.asarray(
Expand Down Expand Up @@ -137,7 +141,11 @@ def test_box_iou_rotated_iof_cpu(self):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_box_iou_rotated_iof(self, device):
np_boxes1 = np.asarray(
Expand Down