Skip to content

Commit

Permalink
Bump version to 2.0.1 (#2831)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida authored and yz87 committed Jul 5, 2023
1 parent 3df6414 commit f4772e3
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc.
| BallQuery | ||| | |
| BBoxOverlaps | |||||
| BorderAlign | || | | |
| BoxIouRotated |||| | |
| BoxIouRotated |||| | |
| BoxIouQuadri ||| | | |
| CARAFE | ||| | |
| ChamferDistance | || | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| BallQuery | ||| | |
| BBoxOverlaps | |||||
| BorderAlign | || | | |
| BoxIouRotated |||| | |
| BoxIouRotated |||| | |
| BoxIouQuadri ||| | | |
| CARAFE | ||| | |
| ChamferDistance | || | | |
Expand Down
5 changes: 5 additions & 0 deletions mmcv/ops/box_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def box_iou_rotated(bboxes1: torch.Tensor,
flip_mat[-1] = -1
bboxes1 = bboxes1 * flip_mat
bboxes2 = bboxes2 * flip_mat
if bboxes1.device.type == 'npu':
scale_mat = bboxes1.new_ones(bboxes1.shape[-1])
scale_mat[-1] = 1.0 / 0.01745329252
bboxes1 = bboxes1 * scale_mat
bboxes2 = bboxes2 * scale_mat
bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous()
ext_module.box_iou_rotated(
Expand Down
47 changes: 47 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);

void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
at::Tensor boxes = at::ones_like(boxes1);
at::Tensor query_boxes = at::ones_like(boxes2);
boxes = boxes1.transpose(0, 1).unsqueeze(0);
query_boxes = boxes2.transpose(0, 1).unsqueeze(0);

bool is_trans = false;
string modeStr = "iou";
if (mode_flag == 1) {
modeStr = "iof";
}
bool is_cross = true;
if (aligned) {
is_cross = false;
}
float v_threshold = 0;
float e_threshold = 0;

OpCommand cmd;
cmd.Name("RotatedIou")
.Input(boxes)
.Input(query_boxes)
.Output(ious)
.Attr("trans", is_trans)
.Attr("mode", modeStr)
.Attr("is_cross", is_cross)
.Attr("v_threshold", v_threshold)
.Attr("e_threshold", e_threshold)
.Run();

if (is_cross) {
ious = ious.view({boxes1.size(0), boxes2.size(0)});
} else {
ious = ious.view({boxes1.size(0), 1});
}
}

REGISTER_NPU_IMPL(box_iou_rotated_impl, box_iou_rotated_npu);
2 changes: 1 addition & 1 deletion mmcv/version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '2.0.0'
__version__ = '2.0.1'


def parse_version_info(version_str: str, length: int = 4) -> tuple:
Expand Down
81 changes: 80 additions & 1 deletion tests/test_ops/test_box_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from mmcv.ops import box_iou_rotated
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


class TestBoxIoURotated:
Expand Down Expand Up @@ -172,6 +172,85 @@ def test_box_iou_rotated_iof(self, device):
ious = box_iou_rotated(boxes1, boxes2, mode='iof', clockwise=False)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)

ious = box_iou_rotated(
boxes1, boxes2, mode='iof', aligned=True, clockwise=False)
assert np.allclose(ious.cpu().numpy(),
np_expect_ious_aligned, atol=1e-4)

@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_box_iou_rotated_npu(self):
from mmcv.ops import box_iou_rotated
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
dtype=np.float32)
np_boxes2 = np.asarray(
[[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
[5.0, 5.0, 6.0, 7.0, 0.4]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424],
[0.0000, 0.0000, 0.3622]],
dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622],
dtype=np.float32)

boxes1 = torch.from_numpy(np_boxes1).npu()
boxes2 = torch.from_numpy(np_boxes2).npu()

# test cw angle definition
ious = box_iou_rotated(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)

ious = box_iou_rotated(boxes1, boxes2, aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)

# test ccw angle definition
boxes1[..., -1] *= -1
boxes2[..., -1] *= -1
ious = box_iou_rotated(boxes1, boxes2, clockwise=False)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)

ious = box_iou_rotated(boxes1, boxes2, aligned=True, clockwise=False)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)

@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_box_iou_rotated_iof_npu(self):
from mmcv.ops import box_iou_rotated
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
dtype=np.float32)
np_boxes2 = np.asarray(
[[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
[5.0, 5.0, 6.0, 7.0, 0.4]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.4959, 0.5306, 0.0000], [0.1823, 0.5420, 0.1832],
[0.0000, 0.0000, 0.4404]],
dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.4959, 0.5420, 0.4404],
dtype=np.float32)

boxes1 = torch.from_numpy(np_boxes1).npu()
boxes2 = torch.from_numpy(np_boxes2).npu()

# test cw angle definition
ious = box_iou_rotated(boxes1, boxes2, mode='iof')
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)

ious = box_iou_rotated(boxes1, boxes2, mode='iof', aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)

# test ccw angle definition
boxes1[..., -1] *= -1
boxes2[..., -1] *= -1
ious = box_iou_rotated(boxes1, boxes2, mode='iof', clockwise=False)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)

ious = box_iou_rotated(
boxes1, boxes2, mode='iof', aligned=True, clockwise=False)
assert np.allclose(
Expand Down

0 comments on commit f4772e3

Please sign in to comment.