Skip to content

Commit

Permalink
[Feature] Support NmsRotated with cambricon MLU backend (#2643)
Browse files Browse the repository at this point in the history
* [Feature] Support NmsRotated with cambricon MLU backend

* [Feature] remove foolproofs in nms_rotated_mlu.cpp

* [Feature] fix lint in test_nms_rotated.py

* [Feature] fix kMLU not found in nms_rotated.cpp

* [Feature] modify mlu support in nms.py

* [Feature] modify nms_rotated support in ops.md

* [Feature] modify ops/nms.py
  • Loading branch information
liuyuan1-v authored Mar 23, 2023
1 parent e6f434c commit 0d1b224
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ We implement common ops used in detection, segmentation, etc.
| ModulatedDeformConv2d |||| ||
| MultiScaleDeformableAttn | ||| | |
| NMS |||| ||
| NMSRotated ||| | ||
| NMSRotated ||| | ||
| NMSQuadri ||| | | |
| PixelGroup || | | | |
| PointsInBoxes ||| | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| ModulatedDeformConv2d |||| ||
| MultiScaleDeformableAttn | ||| | |
| NMS |||| ||
| NMSRotated ||| | ||
| NMSRotated ||| | ||
| NMSQuadri ||| | | |
| PixelGroup || | | | |
| PointsInBoxes ||| | | |
Expand Down
53 changes: 53 additions & 0 deletions mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) {
if (boxes.numel() == 0) {
return at::empty({0}, boxes.options().dtype(at::kLong));
}

int boxes_num = boxes.size(0);
auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores);
auto output = at::empty({boxes_num}, boxes.options().dtype(at::kInt));
auto output_size = at::empty({1}, scores.options().dtype(at::kInt));

MluOpTensorDescriptor boxes_desc, scores_desc, output_desc;
boxes_desc.set(boxes_);
scores_desc.set(scores_);
output_desc.set(output);

// workspace
size_t workspace_size = 0;
auto handle = mluOpGetCurrentHandle();
mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size);
auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));

auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
auto boxes_ptr = boxes_impl->cnnlMalloc();
auto scores_impl = torch_mlu::getMluTensorImpl(scores_);
auto scores_ptr = scores_impl->cnnlMalloc();
auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
auto workspace_ptr = workspace_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(output);
auto output_ptr = output_impl->cnnlMalloc();
auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
auto output_size_ptr = output_size_impl->cnnlMalloc();

mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr,
scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size,
output_desc.desc(), output_ptr, (int *)output_size_ptr);
int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
auto ret = output.to(boxes.options().dtype(at::kLong));
return ret.slice(0, 0, output_num);
}
9 changes: 9 additions & 0 deletions mmcv/ops/csrc/pytorch/nms_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
const Tensor labels, const float iou_threshold);
#endif

#ifdef MMCV_WITH_MLU
Tensor nms_rotated_mlu(const Tensor dets, const Tensor scores,
const float iou_threshold);
#endif

// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
Expand All @@ -36,6 +41,10 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
return nms_rotated_npu(dets, scores, labels, iou_threshold);
#else
AT_ERROR("Not compiled with NPU support");
#endif
#ifdef MMCV_WITH_MLU
} else if (dets.device().type() == at::kMLU) {
return nms_rotated_mlu(dets, scores, iou_threshold);
#endif
}

Expand Down
9 changes: 5 additions & 4 deletions mmcv/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,12 @@ def nms_rotated(dets: Tensor,
input_labels = scores.new_empty(0, dtype=torch.int)
else:
input_labels = labels
if dets.device.type == 'npu':
if dets.device.type in ('npu', 'mlu'):
order = scores.new_empty(0, dtype=torch.long)
coefficient = 57.29578 # 180 / PI
for i in range(dets.size()[0]):
dets_cw[i][4] *= coefficient # radians to angle
if dets.device.type == 'npu':
coefficient = 57.29578 # 180 / PI
for i in range(dets.size()[0]):
dets_cw[i][4] *= coefficient # radians to angle
keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
input_labels, iou_threshold,
multi_label)
Expand Down
14 changes: 11 additions & 3 deletions tests/test_ops/test_nms_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import torch

from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


class TestNmsRotated:
Expand All @@ -16,7 +16,11 @@ class TestNmsRotated:
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_ml_nms_rotated(self, device):
from mmcv.ops import nms_rotated
Expand Down Expand Up @@ -58,7 +62,11 @@ def test_ml_nms_rotated(self, device):
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_nms_rotated(self, device):
from mmcv.ops import nms_rotated
Expand Down

0 comments on commit 0d1b224

Please sign in to comment.