diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md
index 2ee4380aed..ac9b36a5be 100644
--- a/docs/en/understand_mmcv/ops.md
+++ b/docs/en/understand_mmcv/ops.md
@@ -58,5 +58,5 @@ We implement common ops used in detection, segmentation, etc.
 | ThreeNN | | √ | √ | | |
 | TINShift | | √ | √ | | |
 | UpFirDn2d | | √ | | | |
-| Voxelization | √ | √ | √ | | |
+| Voxelization | √ | √ | √ | | √ |
 | PrRoIPool | | √ | | | |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index 862946e7c4..70179d05d2 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -58,5 +58,5 @@ MMCV 提供了检测、分割等任务中常用的算子
 | ThreeNN | | √ | √ | | |
 | TINShift | | √ | √ | | |
 | UpFirDn2d | | √ | | | |
-| Voxelization | √ | √ | √ | | |
+| Voxelization | √ | √ | √ | | √ |
 | PrRoIPool | | √ | | | |
diff --git a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
new file mode 100644
index 0000000000..13e50401f9
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
@@ -0,0 +1,70 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+// Declaration of the function that REGISTER_NPU_IMPL (end of this file) pairs
+// with the NPU implementation defined below.
+int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
+                               at::Tensor &coors,
+                               at::Tensor &num_points_per_voxel,
+                               const std::vector<float> voxel_size,
+                               const std::vector<float> coors_range,
+                               const int max_points, const int max_voxels,
+                               const int NDim = 3);
+
+// NPU implementation of hard voxelization. Runs the "Voxelization" operator
+// with `points`, `voxel_size` and `coors_range` as inputs; `voxels`, `coors`
+// and `num_points_per_voxel` are passed as its outputs. Returns the first
+// element of the operator's fourth output (`voxel_num`).
+// `voxel_size` and `coors_range` must hold at least 3 and 6 floats
+// respectively. `NDim` is not used here.
+int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels,
+                              at::Tensor &coors,
+                              at::Tensor &num_points_per_voxel,
+                              const std::vector<float> voxel_size,
+                              const std::vector<float> coors_range,
+                              const int max_points, const int max_voxels,
+                              const int NDim = 3) {
+  at::Tensor voxel_num_tmp = OpPreparation::ApplyTensor(points, {1});
+  at::Tensor voxel_num = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
+      voxel_num_tmp, at::kInt);
+
+  // from_blob takes a non-const pointer, hence the const_cast. The shapes are
+  // fixed: 3 floats of voxel_size, 6 floats of coors_range.
+  at::Tensor voxel_size_cpu = at::from_blob(
+      const_cast<float *>(voxel_size.data()), {3}, dtype(at::kFloat));
+  at::Tensor voxel_size_npu =
+      CalcuOpUtil::CopyTensorHostToDevice(voxel_size_cpu);
+
+  at::Tensor coors_range_cpu = at::from_blob(
+      const_cast<float *>(coors_range.data()), {6}, dtype(at::kFloat));
+  at::Tensor coors_range_npu =
+      CalcuOpUtil::CopyTensorHostToDevice(coors_range_cpu);
+
+  int64_t max_points_ = static_cast<int64_t>(max_points);
+  int64_t max_voxels_ = static_cast<int64_t>(max_voxels);
+
+  // only support true now
+  bool deterministic = true;
+
+  OpCommand cmd;
+  cmd.Name("Voxelization")
+      .Input(points)
+      .Input(voxel_size_npu)
+      .Input(coors_range_npu)
+      .Output(voxels)
+      .Output(coors)
+      .Output(num_points_per_voxel)
+      .Output(voxel_num)
+      .Attr("max_points", max_points_)
+      .Attr("max_voxels", max_voxels_)
+      .Attr("deterministic", deterministic)
+      .Run();
+  // voxel_num was cast to at::kInt above, so it is read back as int.
+  auto voxel_num_cpu = voxel_num.to(at::kCPU);
+  int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
+  return voxel_num_int;
+}
+
+REGISTER_NPU_IMPL(hard_voxelize_forward_impl, hard_voxelize_forward_npu);
diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py
index d34797caf2..cd01eb46e6 100644
--- a/tests/test_ops/test_voxelization.py
+++ b/tests/test_ops/test_voxelization.py
@@ -4,7 +4,7 @@
 import torch
 
 from mmcv.ops import Voxelization
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
 
 
 def _get_voxel_points_indices(points, coors, voxel):
@@ -172,3 +172,38 @@ def test_voxelization_mlu(device_type):
     assert np.all(coors == expected_coors)
     assert np.all(voxels == expected_voxels)
     assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
+
+
+@pytest.mark.parametrize('device_type', [
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
+def test_voxelization_npu(device_type):
+    voxel_size = [0.5, 0.5, 0.5]
+    point_cloud_range = [0, -40, -3, 70.4, 40, 1]
+
+    voxel_dict = np.load(
+        'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
+    expected_coors = voxel_dict['coors']
+    expected_voxels = voxel_dict['voxels']
+    expected_num_points_per_voxel = voxel_dict['num_points_per_voxel']
+    points = voxel_dict['points']
+
+    points = torch.tensor(points)
+    max_num_points = 1000
+    hard_voxelization = Voxelization(voxel_size, point_cloud_range,
+                                     max_num_points)
+
+    device = torch.device(device_type)
+
+    # test hard_voxelization on npu
+    points = points.contiguous().to(device)
+    coors, voxels, num_points_per_voxel = hard_voxelization.forward(points)
+    coors = coors.cpu().detach().numpy()
+    voxels = voxels.cpu().detach().numpy()
+    num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy()
+    assert np.all(coors == expected_coors)
+    assert np.all(voxels == expected_voxels)
+    assert np.all(num_points_per_voxel == expected_num_points_per_voxel)