Skip to content

Commit

Permalink
[Feature] Add the support of voxelization op for ascend device (open-…
Browse files Browse the repository at this point in the history
…mmlab#2613)

* update

* update
  • Loading branch information
dflhw authored and CokeDong committed Apr 6, 2023
1 parent fd27378 commit ead6aa6
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,5 @@ We implement common ops used in detection, segmentation, etc.
| ThreeNN | ||| | |
| TINShift | ||| | |
| UpFirDn2d | || | | |
| Voxelization |||| | |
| Voxelization |||| | |
| PrRoIPool | || | | |
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,5 @@ MMCV 提供了检测、分割等任务中常用的算子
| ThreeNN | ||| | |
| TINShift | ||| | |
| UpFirDn2d | || | | |
| Voxelization |||| | |
| Voxelization |||| | |
| PrRoIPool | || | | |
59 changes: 59 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3);

int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
at::Tensor voxel_num_tmp = OpPreparation::ApplyTensor(points, {1});
at::Tensor voxel_num = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
voxel_num_tmp, at::kInt);

at::Tensor voxel_size_cpu = at::from_blob(
const_cast<float *>(voxel_size.data()), {3}, dtype(at::kFloat));
at::Tensor voxel_size_npu =
CalcuOpUtil::CopyTensorHostToDevice(voxel_size_cpu);

at::Tensor coors_range_cpu = at::from_blob(
const_cast<float *>(coors_range.data()), {6}, dtype(at::kFloat));
at::Tensor coors_range_npu =
CalcuOpUtil::CopyTensorHostToDevice(coors_range_cpu);

int64_t max_points_ = (int64_t)max_points;
int64_t max_voxels_ = (int64_t)max_voxels;

// only support true now
bool deterministic = true;

OpCommand cmd;
cmd.Name("Voxelization")
.Input(points)
.Input(voxel_size_npu)
.Input(coors_range_npu)
.Output(voxels)
.Output(coors)
.Output(num_points_per_voxel)
.Output(voxel_num)
.Attr("max_points", max_points_)
.Attr("max_voxels", max_voxels_)
.Attr("deterministic", deterministic)
.Run();
auto voxel_num_cpu = voxel_num.to(at::kCPU);
int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
return voxel_num_int;
}

REGISTER_NPU_IMPL(hard_voxelize_forward_impl, hard_voxelize_forward_npu);
37 changes: 36 additions & 1 deletion tests/test_ops/test_voxelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from mmcv.ops import Voxelization
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


def _get_voxel_points_indices(points, coors, voxel):
Expand Down Expand Up @@ -172,3 +172,38 @@ def test_voxelization_mlu(device_type):
assert np.all(coors == expected_coors)
assert np.all(voxels == expected_voxels)
assert np.all(num_points_per_voxel == expected_num_points_per_voxel)


@pytest.mark.parametrize('device_type', [
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_voxelization_npu(device_type):
voxel_size = [0.5, 0.5, 0.5]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]

voxel_dict = np.load(
'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
expected_coors = voxel_dict['coors']
expected_voxels = voxel_dict['voxels']
expected_num_points_per_voxel = voxel_dict['num_points_per_voxel']
points = voxel_dict['points']

points = torch.tensor(points)
max_num_points = 1000
hard_voxelization = Voxelization(voxel_size, point_cloud_range,
max_num_points)

device = torch.device(device_type)

# test hard_voxelization on npu
points = points.contiguous().to(device)
coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
coors = coors.cpu().detach().numpy()
voxels = voxels.cpu().detach().numpy()
num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy()
assert np.all(coors == expected_coors)
assert np.all(voxels == expected_voxels)
assert np.all(num_points_per_voxel == expected_num_points_per_voxel)

0 comments on commit ead6aa6

Please sign in to comment.