From 2dc6863bba0afb72d531fc78c58a3fd9a412e982 Mon Sep 17 00:00:00 2001 From: liuyuan1-v Date: Tue, 7 Mar 2023 15:29:39 +0800 Subject: [PATCH 1/7] [Feature] Support NmsRotated with cambricon MLU backend --- docs/en/understand_mmcv/ops.md | 2 +- mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp | 73 +++++++++++++++++++ mmcv/ops/csrc/pytorch/nms_rotated.cpp | 11 +++ mmcv/ops/nms.py | 9 +++ tests/test_ops/test_nms_rotated.py | 14 +++- 5 files changed, 105 insertions(+), 4 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 95cf94de5b..e7212bbddc 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -35,7 +35,7 @@ We implement common ops used in detection, segmentation, etc. | ModulatedDeformConv2d | √ | √ | √ | | √ | | MultiScaleDeformableAttn | | √ | √ | | | | NMS | √ | √ | √ | | √ | -| NMSRotated | √ | √ | | | √ | +| NMSRotated | √ | √ | √ | | √ | | NMSQuadri | √ | √ | | | | | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp new file mode 100644 index 0000000000..4d4e53fccf --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp @@ -0,0 +1,73 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "mlu_common_helper.h" + +Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) { + // dimension parameters check + TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ", + boxes.dim(), "D"); + TORCH_CHECK(boxes.size(1) == 5, + "boxes should have 5 elements in dimension 1, got ", + boxes.size(1)); + TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", + scores.dim(), "D"); + + TORCH_CHECK(boxes.size(0) == scores.size(0), "boxes and scores should have", + " same elements in dimension 0, boxes got ", boxes.size(0), + ", scores got ", scores.size(0)); + + // data type check + TORCH_CHECK(boxes.scalar_type() == scores.scalar_type(), + "boxes should have the same type as scores"); + TORCH_CHECK(boxes.scalar_type() == at::kFloat, + "data type of boxes should be Float, got ", boxes.scalar_type()); + + if (boxes.numel() == 0) { + return at::empty({0}, boxes.options().dtype(at::kLong)); + } + + int boxes_num = boxes.size(0); + + auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes); + auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores); + auto output = at::empty({boxes_num}, boxes.options().dtype(at::kInt)); + auto output_size = at::empty({1}, scores.options().dtype(at::kInt)); + + MluOpTensorDescriptor boxes_desc, scores_desc, output_desc; + boxes_desc.set(boxes_); + scores_desc.set(scores_); + output_desc.set(output); + + // workspace + size_t workspace_size = 0; + auto handle = mluOpGetCurrentHandle(); + mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size); + auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte)); + + auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_); + auto boxes_ptr = boxes_impl->cnnlMalloc(); + auto scores_impl = torch_mlu::getMluTensorImpl(scores_); + auto scores_ptr = scores_impl->cnnlMalloc(); + auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); + auto workspace_ptr = workspace_impl->cnnlMalloc(); + auto output_impl = torch_mlu::getMluTensorImpl(output); + auto output_ptr = output_impl->cnnlMalloc(); + auto output_size_impl = torch_mlu::getMluTensorImpl(output_size); + auto output_size_ptr = output_size_impl->cnnlMalloc(); + + mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr, + scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size, + output_desc.desc(), output_ptr, (int *)output_size_ptr); + int output_num = *static_cast(output_size.cpu().data_ptr()); + auto ret = output.to(boxes.options().dtype(at::kLong)); + return ret.slice(0, 0, output_num); +} diff --git a/mmcv/ops/csrc/pytorch/nms_rotated.cpp b/mmcv/ops/csrc/pytorch/nms_rotated.cpp index b07ed5aa11..092ac5edd2 100644 --- a/mmcv/ops/csrc/pytorch/nms_rotated.cpp +++ b/mmcv/ops/csrc/pytorch/nms_rotated.cpp @@ -17,6 +17,11 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores, const Tensor labels, const float iou_threshold); #endif +#ifdef MMCV_WITH_MLU +Tensor nms_rotated_mlu(const Tensor dets, const Tensor scores, + const float iou_threshold); +#endif + // Interface for Python // inline is needed to prevent multiple function definitions when this header is // included by different cpps @@ -36,6 +41,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, return nms_rotated_npu(dets, scores, labels, iou_threshold); #else AT_ERROR("Not compiled with NPU support"); +#endif + } else if (dets.device().type() == at::kMLU) { +#ifdef MMCV_WITH_MLU + return nms_rotated_mlu(dets, scores, iou_threshold); +#else + AT_ERROR("Not compiled with MLU support"); #endif } diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index 00d22f2ac5..1729facee1 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -470,6 +470,15 @@ def nms_rotated(dets: Tensor, dim=1) return dets, keep_inds + if dets.device.type == 'mlu': + order = scores.new_empty(0, dtype=torch.long) + keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, + input_labels, iou_threshold, + multi_label) + dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)), + dim=1) + return dets, keep_inds + if multi_label: dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1) # type: ignore else: diff --git a/tests/test_ops/test_nms_rotated.py b/tests/test_ops/test_nms_rotated.py index bee562a6f1..35d884432f 100644 --- a/tests/test_ops/test_nms_rotated.py +++ b/tests/test_ops/test_nms_rotated.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE, IS_MLU_AVAILABLE class TestNmsRotated: @@ -16,7 +16,11 @@ class TestNmsRotated: pytest.param( 'cuda', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_ml_nms_rotated(self, device): from mmcv.ops import nms_rotated @@ -58,7 +62,11 @@ def test_ml_nms_rotated(self, device): pytest.param( 'cuda', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_nms_rotated(self, device): from mmcv.ops import nms_rotated From 58c1ed9a2446fe2e05eb52ca399e45273fb8f094 Mon Sep 17 00:00:00 2001 From: liuyuan1-v Date: Tue, 7 Mar 2023 17:14:01 +0800 Subject: [PATCH 2/7] [Feature] remove foolproofs in nms_rotated_mlu.cpp --- mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp | 20 ------------------- mmcv/ops/nms.py | 9 --------- 2 files changed, 29 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp index 4d4e53fccf..9b45a17805 100644 --- a/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp @@ -12,31 +12,11 @@ #include "mlu_common_helper.h" Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) { - // dimension parameters check - TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ", - boxes.dim(), "D"); - TORCH_CHECK(boxes.size(1) == 5, - "boxes should have 5 elements in dimension 1, got ", - boxes.size(1)); - TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", - scores.dim(), "D"); - - TORCH_CHECK(boxes.size(0) == scores.size(0), "boxes and scores should have", - " same elements in dimension 0, boxes got ", boxes.size(0), - ", scores got ", scores.size(0)); - - // data type check - TORCH_CHECK(boxes.scalar_type() == scores.scalar_type(), - "boxes should have the same type as scores"); - TORCH_CHECK(boxes.scalar_type() == at::kFloat, - "data type of boxes should be Float, got ", boxes.scalar_type()); - if (boxes.numel() == 0) { return at::empty({0}, boxes.options().dtype(at::kLong)); } int boxes_num = boxes.size(0); - auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes); auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores); auto output = at::empty({boxes_num}, boxes.options().dtype(at::kInt)); diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index 1729facee1..00d22f2ac5 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -470,15 +470,6 @@ def nms_rotated(dets: Tensor, dim=1) return dets, keep_inds - if dets.device.type == 'mlu': - order = scores.new_empty(0, dtype=torch.long) - keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, - input_labels, iou_threshold, - multi_label) - dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)), - dim=1) - return dets, keep_inds - if multi_label: dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1) # type: ignore else: From fc1600daaaa372cd84e5b0c438132e9c3476d5f1 Mon Sep 17 00:00:00 2001 From: liuyuan1-v Date: Wed, 8 Mar 2023 10:27:37 +0800 Subject: [PATCH 3/7] [Feature] fix lint in test_nms_rotated.py --- tests/test_ops/test_nms_rotated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ops/test_nms_rotated.py b/tests/test_ops/test_nms_rotated.py index 35d884432f..88b41fec85 100644 --- a/tests/test_ops/test_nms_rotated.py +++ b/tests/test_ops/test_nms_rotated.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE class TestNmsRotated: From 580c024bbbe2b8309b73cdd9717fb0a815f2a722 Mon Sep 17 00:00:00 2001 From: liuyuan1-v Date: Wed, 8 Mar 2023 12:00:48 +0800 Subject: [PATCH 4/7] [Feature] fix kMLU not found in nms_rotated.cpp --- mmcv/ops/csrc/pytorch/nms_rotated.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/nms_rotated.cpp b/mmcv/ops/csrc/pytorch/nms_rotated.cpp index 092ac5edd2..1d49c37dd6 100644 --- a/mmcv/ops/csrc/pytorch/nms_rotated.cpp +++ b/mmcv/ops/csrc/pytorch/nms_rotated.cpp @@ -42,11 +42,9 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, #else AT_ERROR("Not compiled with NPU support"); #endif - } else if (dets.device().type() == at::kMLU) { #ifdef MMCV_WITH_MLU + } else if (dets.device().type() == at::kMLU) { return nms_rotated_mlu(dets, scores, iou_threshold); -#else - AT_ERROR("Not compiled with MLU support"); #endif } From c41d6b6f38b8d4de08339c943e342d76aa987f35 Mon Sep 17 00:00:00 2001 From: liuyuan1-v Date: Tue, 21 Mar 2023 20:30:18 +0800 Subject: [PATCH 5/7] [Feature] modify mlu support in nms.py --- mmcv/ops/nms.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index 00d22f2ac5..04fc50f3d0 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -458,11 +458,12 @@ def nms_rotated(dets: Tensor, input_labels = scores.new_empty(0, dtype=torch.int) else: input_labels = labels - if dets.device.type == 'npu': + if dets.device.type in ('npu', 'mlu'): order = scores.new_empty(0, dtype=torch.long) coefficient = 57.29578 # 180 / PI - for i in range(dets.size()[0]): - dets_cw[i][4] *= coefficient # radians to angle + if dets.device.type == 'npu': + for i in range(dets.size()[0]): + dets_cw[i][4] *= coefficient # radians to angle keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, input_labels, iou_threshold, multi_label) From 3bc0cedf10127dd9912be7c8758a07cc11b50240 Mon Sep 17 00:00:00 2001 From: liuyuan1-v Date: Wed, 22 Mar 2023 16:22:33 +0800 Subject: [PATCH 6/7] [Feature] modify nms_rotated support in ops.md --- docs/zh_cn/understand_mmcv/ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index b4ace828d8..81092144a7 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -35,7 +35,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | ModulatedDeformConv2d | √ | √ | √ | | √ | | MultiScaleDeformableAttn | | √ | √ | | | | NMS | √ | √ | √ | | √ | -| NMSRotated | √ | √ | | | √ | +| NMSRotated | √ | √ | √ | | √ | | NMSQuadri | √ | √ | | | | | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | From f8fd6deebc1509bbe17a403fbaa0e5de8fd00d89 Mon Sep 17 00:00:00 2001 From: liuyuan1-v Date: Thu, 23 Mar 2023 10:00:59 +0800 Subject: [PATCH 7/7] [Feature] modify ops/nms.py --- mmcv/ops/nms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index 04fc50f3d0..5115a95f62 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -460,8 +460,8 @@ def nms_rotated(dets: Tensor, input_labels = labels if dets.device.type in ('npu', 'mlu'): order = scores.new_empty(0, dtype=torch.long) - coefficient = 57.29578 # 180 / PI if dets.device.type == 'npu': + coefficient = 57.29578 # 180 / PI for i in range(dets.size()[0]): dets_cw[i][4] *= coefficient # radians to angle keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,