Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add the implementation of diff_iou_rotated with mlu-ops #2852

Merged
merged 5 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ We implement common ops used in detection, segmentation, etc.
| Correlation | | √ | | | |
| Deformable Convolution v1/v2 | √ | √ | √ | | √ |
| Deformable RoIPool | | √ | √ | | √ |
| DiffIoURotated | | √ | | | |
| DiffIoURotated | | √ | | | |
| DynamicScatter | | √ | | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| Correlation | | √ | | | |
| Deformable Convolution v1/v2 | √ | √ | √ | | √ |
| Deformable RoIPool | | √ | √ | | √ |
| DiffIoURotated | | √ | | | |
| DiffIoURotated | | √ | | | |
| DynamicScatter | | √ | | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |
Expand Down
55 changes: 55 additions & 0 deletions mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*************************************************************************
* Copyright (C) 2023 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

Tensor diff_iou_rotated_sort_vertices_forward_mlu(Tensor vertices, Tensor mask,
Tensor num_valid) {
// params check
TORCH_CHECK(vertices.scalar_type() == at::kFloat,
"vertices type should be Float, got ", vertices.scalar_type());
TORCH_CHECK(mask.scalar_type() == at::kBool, "mask should be Bool, got ",
mask.scalar_type());
TORCH_CHECK(num_valid.scalar_type() == at::kInt,
"num_valid type should be Int32, got ", num_valid.scalar_type());
TORCH_CHECK(vertices.size(2) == 24, "vertices.dim(2) should be 24, got ",
vertices.size(2));
TORCH_CHECK(mask.size(2) == 24, "mask.dim(2) should be 24, got ",
mask.size(2));

// zero-element check
if (vertices.numel() == 0) {
return at::empty({0}, num_valid.options().dtype(at::kInt));
}

auto idx = at::empty({vertices.size(0), vertices.size(1), 9},
num_valid.options().dtype(at::kInt));

INITIAL_MLU_PARAM_WITH_TENSOR(vertices);
INITIAL_MLU_PARAM_WITH_TENSOR(mask);
INITIAL_MLU_PARAM_WITH_TENSOR(num_valid);
INITIAL_MLU_PARAM_WITH_TENSOR(idx);

// get compute handle
auto handle = mluOpGetCurrentHandle();

// launch kernel
mluOpDiffIouRotatedSortVerticesForward(
handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr);
return idx;
}

Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);

REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, MLU,
diff_iou_rotated_sort_vertices_forward_mlu);
17 changes: 15 additions & 2 deletions mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,21 @@
#include "pytorch_device_registry.hpp"

#define MLUOP_MAJOR 0
#define MLUOP_MINOR 6
#define MLUOP_PATCHLEVEL 0
#define MLUOP_MINOR 7
#define MLUOP_PATCHLEVEL 1

/*************************************************************************
* This MACRO contains operations of simple tensor to mlu-tensor.
* _contiguous, _desc, _impl, _ptr will be automatically generated in
* this MACRO.
*************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \
auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \
NAME, NAME.suggest_memory_format()); \
MluOpTensorDescriptor NAME##_desc; \
NAME##_desc.set(NAME##_contigous); \
auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
auto NAME##_ptr = NAME##_impl->cnnlMalloc();

mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
Expand Down
13 changes: 0 additions & 13 deletions mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,6 @@
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

/*************************************************************************
* This MACRO contains operations of simple tensor to mlu-tensor.
* _contiguous, _desc, _impl, _ptr will be automatically generated in
* this MACRO.
*************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \
auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \
NAME, NAME.suggest_memory_format()); \
MluOpTensorDescriptor NAME##_desc; \
NAME##_desc.set(NAME##_contigous); \
auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
auto NAME##_ptr = NAME##_impl->cnnlMalloc();

Tensor MsDeformAttnForwardLauncher(const Tensor& value,
const Tensor& spatial_shapes,
const Tensor& level_start_index,
Expand Down
40 changes: 30 additions & 10 deletions tests/test_ops/test_diff_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,23 @@
import torch

from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
torch.backends.mlu.matmul.allow_tf32 = False

@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_diff_iou_rotated_2d():

@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_diff_iou_rotated_2d(device):
np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
[0.5, 0.5, 1., 1., .0]]],
Expand All @@ -19,17 +31,25 @@ def test_diff_iou_rotated_2d():
[1.5, 1.5, 1., 1., .0]]],
dtype=np.float32)

boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)

np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]])
ious = diff_iou_rotated_2d(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_diff_iou_rotated_3d():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_diff_iou_rotated_3d(device):
np_boxes1 = np.asarray(
[[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
Expand All @@ -41,8 +61,8 @@ def test_diff_iou_rotated_3d():
[-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]],
dtype=np.float32)

boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)

np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]])
ious = diff_iou_rotated_3d(boxes1, boxes2)
Expand Down