Skip to content

Commit

Permalink
Pick changes from 1.x branch (open-mmlab#2738)
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirchhoff2021 authored and akozlov-outrider committed May 8, 2023
1 parent 0c8cd33 commit 9bd4741
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ We implement common ops used in detection, segmentation, etc.
| RoIPool | ||| ||
| RoIAlignRotated |||| | |
| RiRoIAlignRotated | || | | |
| RoIAlign |||| | |
| RoIAlign |||| | |
| RoIAwarePool3d | ||| | |
| SAConv2d | || | | |
| SigmoidFocalLoss | ||| ||
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| RoIPool | ||| ||
| RoIAlignRotated |||| | |
| RiRoIAlignRotated | || | | |
| RoIAlign |||| | |
| RoIAlign |||| | |
| RoIAwarePool3d | ||| | |
| SAConv2d | || | | |
| SigmoidFocalLoss | ||| ||
Expand Down
32 changes: 21 additions & 11 deletions mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,33 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
if (mode == 1) {
modeStr = "iof";
}
float offset_ = 1;
if (offset == 0) {
offset_ = 0.01;
at::Tensor bboxesFP32 = bboxes2;
at::Tensor gtboxesFP32 = bboxes1;
if (bboxes2.scalar_type() != at::ScalarType::Float) {
bboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxes2, at::kFloat);
gtboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxes1, at::kFloat);
}
at::Tensor bboxes = at::ones_like(bboxes2);
at::Tensor gtboxes = at::ones_like(bboxes1);
bboxes = aligned ? bboxes2.transpose(0, 1) : bboxes2;
gtboxes = aligned ? bboxes1.transpose(0, 1) : bboxes1;
c10::SmallVector<int64_t, SIZE> iousSize = {gtboxesFP32.size(0),
bboxesFP32.size(0)};
if (aligned) {
iousSize = {gtboxesFP32.size(0), 1};
}
at::Tensor iousFP32 = OpPreparation::ApplyTensor(bboxesFP32, iousSize);
bboxesFP32 = aligned ? bboxesFP32.transpose(0, 1) : bboxesFP32;
gtboxesFP32 = aligned ? gtboxesFP32.transpose(0, 1) : gtboxesFP32;
OpCommand cmd;
cmd.Name("Iou")
.Input(bboxes)
.Input(gtboxes)
.Output(ious)
.Input(bboxesFP32)
.Input(gtboxesFP32)
.Output(iousFP32)
.Attr("mode", modeStr)
.Attr("eps", offset_)
.Attr("eps", (float)offset)
.Attr("aligned", aligned)
.Run();
if (bboxes2.scalar_type() != at::ScalarType::Float) {
iousFP32 = NPUNativeFunctions::npu_dtype_cast(iousFP32, at::kHalf);
}
ious.copy_(iousFP32);
}

REGISTER_NPU_IMPL(bbox_overlaps_impl, bbox_overlaps_npu);
8 changes: 4 additions & 4 deletions mmcv/ops/csrc/pytorch/npu/nms_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ using namespace NPU_NAME_SPACE;
using namespace std;

Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
at::Tensor boxed_offest = at_npu::native::OpPreparation::ApplyTensor(boxes);
at::Tensor ones_tensor =
at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1);
at::add_out(boxed_offest, boxes, ones_tensor, offset);
TORCH_CHECK((boxes.scalar_type() == at::ScalarType::Float),
"The type of boxes tensor passed in nms_npu should be float");
int64_t offset_64 = offset;
at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor(
{}, boxes.options().dtype(at::kFloat), boxes)
.fill_(iou_threshold);
Expand All @@ -29,6 +28,7 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
.Input(max_outputsize_y)
.Input(iou_threshold_y)
.Input(scores_threshold_y)
.Attr("offset", offset_64)
.Output(output)
.Run();
auto outputsizeBool = at::gt(output, -1);
Expand Down
68 changes: 68 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void roi_align_forward_npu(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x, int aligned_height,
int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode, bool aligned) {
if (!aligned) {
LOG(WARNING) << "The [aligned] attr in roi_align op is false";
}
int64_t aligned_height_64 = aligned_height;
int64_t aligned_width_64 = aligned_width;
int64_t sampling_ratio_64 = sampling_ratio;
int64_t roi_end_mode = 0;
OpCommand cmd;
cmd.Name("ROIAlign")
.Input(input)
.Input(rois)
.Output(output)
.Attr("spatial_scale", spatial_scale)
.Attr("pooled_height", aligned_height_64)
.Attr("pooled_width", aligned_width_64)
.Attr("sample_num", sampling_ratio_64)
.Attr("roi_end_mode", roi_end_mode)
.Run();
}

void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y,
Tensor argmax_x, Tensor grad_input,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned) {
int64_t aligned_height_64 = aligned_height;
int64_t aligned_width_64 = aligned_width;
int64_t sampling_ratio_64 = sampling_ratio;
int64_t roi_end_mode = 0;
c10::SmallVector<int64_t, SIZE> xdiff_shape =
at_npu::native::array_to_small_vector(grad_input.sizes());
OpCommand cmd;
cmd.Name("ROIAlignGrad")
.Input(grad_output)
.Input(rois)
.Output(grad_input)
.Attr("xdiff_shape", xdiff_shape)
.Attr("pooled_width", aligned_width_64)
.Attr("pooled_height", aligned_height_64)
.Attr("spatial_scale", spatial_scale)
.Attr("sample_num", sampling_ratio_64)
.Attr("roi_end_mode", roi_end_mode)
.Run();
}

void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);

void roi_align_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax_y,
Tensor argmax_x, Tensor grad_input,
int aligned_height, int aligned_width,
float spatial_scale, int sampling_ratio,
int pool_mode, bool aligned);

REGISTER_NPU_IMPL(roi_align_forward_impl, roi_align_forward_npu);
REGISTER_NPU_IMPL(roi_align_backward_impl, roi_align_backward_npu);
6 changes: 2 additions & 4 deletions mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels,

at::Tensor voxel_size_cpu = at::from_blob(
const_cast<float *>(voxel_size.data()), {3}, dtype(at::kFloat));
at::Tensor voxel_size_npu =
CalcuOpUtil::CopyTensorHostToDevice(voxel_size_cpu);
at::Tensor voxel_size_npu = voxel_size_cpu.to(points.device());

at::Tensor coors_range_cpu = at::from_blob(
const_cast<float *>(coors_range.data()), {6}, dtype(at::kFloat));
at::Tensor coors_range_npu =
CalcuOpUtil::CopyTensorHostToDevice(coors_range_cpu);
at::Tensor coors_range_npu = coors_range_cpu.to(points.device());

int64_t max_points_ = (int64_t)max_points;
int64_t max_voxels_ = (int64_t)max_voxels;
Expand Down
12 changes: 8 additions & 4 deletions tests/test_ops/test_roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import torch

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

_USING_PARROTS = True
try:
Expand Down Expand Up @@ -102,15 +102,19 @@ def _test_roialign_allclose(device, dtype):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
reason='MLU and NPU do not support for 64-bit floating point')),
torch.half
])
def test_roialign(device, dtype):
Expand Down

0 comments on commit 9bd4741

Please sign in to comment.