From 64badfe1b052743df435884abfe2cdf9fe939537 Mon Sep 17 00:00:00 2001
From: duzekun
Date: Wed, 15 Mar 2023 10:38:19 +0800
Subject: [PATCH 1/6] [Feature] Add sparse convolution MLU API

---
 mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h |   4 +-
 mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp | 435 ++++++++++++++++++
 mmcv/ops/csrc/pytorch/spconv_ops.cpp          |  26 ++
 setup.py                                      |   2 +-
 tests/test_ops/test_spconv.py                 |  50 +-
 5 files changed, 495 insertions(+), 22 deletions(-)
 create mode 100644 mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp

diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
index 0d6a9aff48..678dc52029 100644
--- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
+++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
@@ -18,8 +18,8 @@
 #include "pytorch_device_registry.hpp"
 
 #define MLUOP_MAJOR 0
-#define MLUOP_MINOR 4
-#define MLUOP_PATCHLEVEL 2
+#define MLUOP_MINOR 5
+#define MLUOP_PATCHLEVEL 0
 
 mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
 mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
diff --git a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
new file mode 100644
index 0000000000..f32a9d6ffa
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
@@ -0,0 +1,435 @@
+/*************************************************************************
+ * Copyright (C) 2022 Cambricon.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include "mlu_common_helper.h"
+#include "pytorch_device_registry.hpp"
+#include "pytorch_mlu_helper.hpp"
+
+#include <torch/script.h>
+#include <vector>
+
+template <unsigned NDim>
+std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
+    torch::Tensor indices, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
+
+  // The following code is copied from
+  // mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu to ensure the output is
+  // available for network training. The outputs of this function have the
+  // correct shape but wrong values.
+  auto numAct = indices.size(0);
+  auto kernelVolume = kernelSize[0];
+  int sub_m = (int)_subM;
+  int transpose = (int)_transpose;
+  int batch = (int)batchSize;
+  auto coorDim = indices.size(1) - 1;
+
+  for (int i = 1; i < kernelSize.size(); ++i) {
+    kernelVolume *= kernelSize[i];
+  }
+
+  auto outputVolume = outSpatialShape[0];
+  for (int i = 1; i < outSpatialShape.size(); ++i) {
+    outputVolume *= outSpatialShape[i];
+  }
+  torch::Tensor indicePairs = at::full({kernelVolume, 2, numAct}, -1,
+                                       indices.options().dtype(at::kInt));
+  torch::Tensor indiceNum = at::zeros({kernelVolume},
+                                      indices.options().dtype(at::kInt));
+  int out_size = sub_m == 1 ?
+      numAct : std::min(numAct * kernelVolume, batch * outputVolume);
+  torch::Tensor out_indices = at::zeros({out_size, coorDim + 1},
+                                        indices.options().dtype(at::kInt));
+  auto indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      indices, at::MemoryFormat::Contiguous);
+  auto indicePairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      indicePairs, at::MemoryFormat::Contiguous);
+  auto indiceNum_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      indiceNum, at::MemoryFormat::Contiguous);
+  auto out_indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      out_indices, at::MemoryFormat::Contiguous);
+
+  std::vector<int> input_space;
+  std::vector<int> filter_space;
+  std::vector<int> output_space;
+  std::vector<int> padding32;
+  std::vector<int> stride32;
+  std::vector<int> dilation32;
+  for (int i = 0; i < NDim; i++) {
+      input_space.push_back(spatialShape[i]);
+      filter_space.push_back(kernelSize[i]);
+      output_space.push_back(outSpatialShape[i]);
+      padding32.push_back(padding[i]);
+      stride32.push_back(stride[i]);
+      dilation32.push_back(dilation[i]);
+  }
+  MluOpTensorDescriptor indices_desc, out_indices_desc,
+      indicePairs_desc, indiceNum_desc;
+  indices_desc.set(indices_contiguous);
+  indicePairs_desc.set(indicePairs_contiguous);
+  indiceNum_desc.set(indiceNum_contiguous);
+  out_indices_desc.set(out_indices_contiguous);
+  {
+      mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
+      mluOpDataType_t dtype = MLUOP_DTYPE_INT32;
+      std::vector<int> dims;
+      dims = {numAct, coorDim + 1};
+      mluOpSetTensorDescriptor(indices_desc.desc(), layout, dtype,
+                               dims.size(), dims.data());
+      dims = {kernelVolume, 2, numAct};
+      mluOpSetTensorDescriptor(indicePairs_desc.desc(), layout, dtype,
+                               dims.size(), dims.data());
+      dims = {kernelVolume};
+      mluOpSetTensorDescriptor(indiceNum_desc.desc(), layout, dtype,
+                               dims.size(), dims.data());
+      dims = {out_size, coorDim + 1};
+      mluOpSetTensorDescriptor(out_indices_desc.desc(), layout, dtype,
+                               dims.size(), dims.data());
+  }
+
+  mluOpSparseConvolutionDescriptor_t sparse_conv_desc;
+  mluOpCreateSparseConvolutionDescriptor(&sparse_conv_desc);
+  mluOpSetSparseConvolutionDescriptor(sparse_conv_desc, NDim + 2, batch,
+      padding32.data(), stride32.data(), dilation32.data(), input_space.data(),
+      filter_space.data(), output_space.data(), sub_m, transpose, 0);
+
+  auto handle = mluOpGetCurrentHandle();
+  size_t workspace_size = 0;
+  mluOpGetIndicePairsWorkspaceSize(handle, sparse_conv_desc, indices_desc.desc(),
+      indicePairs_desc.desc(), out_indices_desc.desc(), indiceNum_desc.desc(),
+      &workspace_size);
+  auto indice_workspace_size = at::empty(workspace_size,
+                                         indices.options().dtype(at::kByte));
+
+  auto indices_impl = torch_mlu::getMluTensorImpl(indices_contiguous);
+  auto out_indices_impl = torch_mlu::getMluTensorImpl(out_indices_contiguous);
+  auto indicePairs_impl = torch_mlu::getMluTensorImpl(indicePairs_contiguous);
+  auto indiceNum_impl = torch_mlu::getMluTensorImpl(indiceNum_contiguous);
+  auto indice_workspace_impl = torch_mlu::getMluTensorImpl(indice_workspace_size);
+
+  auto indices_ptr = indices_impl->cnnlMalloc();
+  auto out_indices_ptr = out_indices_impl->cnnlMalloc();
+  auto indicePairs_ptr = indicePairs_impl->cnnlMalloc();
+  auto indiceNum_ptr = indiceNum_impl->cnnlMalloc();
+  auto indice_workspace_ptr = indice_workspace_impl->cnnlMalloc();
+
+  mluOpGetIndicePairs(handle, sparse_conv_desc, indices_desc.desc(), indices_ptr,
+      indice_workspace_ptr, workspace_size, indicePairs_desc.desc(),
+      indicePairs_ptr, out_indices_desc.desc(), out_indices_ptr,
+      indiceNum_desc.desc(), indiceNum_ptr);
+  int num_act_out = 0;
+  mluOpGetSparseConvolutionNumActOut(sparse_conv_desc, &num_act_out);
+  mluOpDestroySparseConvolutionDescriptor(sparse_conv_desc);
+  if (!sub_m) {
+      return {out_indices.slice(0, 0, num_act_out), indicePairs, indiceNum};
+  } else {
+      return {indices, indicePairs, indiceNum};
+  }
+}
+
+torch::Tensor IndiceConvForwardMLUKernelLauncher(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
+    torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
+    int64_t _subM) {
+  auto indice_num_cpu = indiceNum.to({torch::kCPU});
+  auto indice_num_cpu_64 = indice_num_cpu.data_ptr();
+  int indice_num_len = indiceNum.numel();
+  int64_t indice_num[indice_num_len];
+  for (int i = 0; i < indice_num_len; ++i) {
+    indice_num[i] = (int64_t)(((int *)indice_num_cpu_64)[i]);
+  }
+
+  // generate empty output
+  int C = filters.dim() == 4 ? filters.size(3) : filters.size(4);
+  torch::Tensor output =
+      at::zeros({numActOut, C}, features.options().dtype(at::kFloat));
+  // generate descriptor
+  auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      features, at::MemoryFormat::Contiguous);
+  auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      filters, at::MemoryFormat::Contiguous);
+  auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      indicePairs, at::MemoryFormat::Contiguous);
+  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      output, at::MemoryFormat::Contiguous);
+
+  MluOpTensorDescriptor features_desc, filters_desc, indice_pairs_desc,
+      output_desc;
+  features_desc.set(features_contiguous);
+  filters_desc.set(filters_contiguous);
+  indice_pairs_desc.set(indice_pairs_contiguous);
+  output_desc.set(output_contiguous);
+
+  // set layout
+  {
+    mluOpTensorLayout_t layout;
+    mluOpDataType_t dtype;
+    int dim;
+    int dims[8];
+
+    // features_desc
+    mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
+    mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims);
+
+    // filters_desc
+    mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
+    mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims);
+
+    // indice_pairs_desc
+    mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim, dims);
+    mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims);
+
+    // output_desc
+    mluOpGetTensorDescriptor(output_desc.desc(), &layout, &dtype, &dim, dims);
+    mluOpSetTensorDescriptor(output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims);
+  }
+
+  auto handle = mluOpGetCurrentHandle();
+  size_t workspace_size = 0;
+  mluOpGetIndiceConvolutionForwardWorkspaceSize(
+      handle, features_desc.desc(), filters_desc.desc(), indice_pairs_desc.desc(),
+      output_desc.desc(), indice_num, numActOut, _inverse, _subM, &workspace_size);
+
+  auto workspace =
+      at::empty(workspace_size, features.options().dtype(at::kByte));
+
+  auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous);
+  auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
+  auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
+  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
+
+  auto features_ptr = features_impl->cnnlMalloc();
+  auto filters_ptr = filters_impl->cnnlMalloc();
+  auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
+  auto workspace_ptr = workspace_impl->cnnlMalloc();
+
+  // outputs
+  auto output_impl = torch_mlu::getMluTensorImpl(output);
+   auto output_ptr = output_impl->cnnlMalloc();
+  mluOpIndiceConvolutionForward(
+      handle, features_desc.desc(), features_ptr, filters_desc.desc(), filters_ptr,
+      indice_pairs_desc.desc(), indice_pairs_ptr, indice_num, numActOut, _inverse, _subM,
+      workspace_ptr, workspace_size, output_desc.desc(), output_ptr);
+
+  return output;
+}
+
+std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
+    int64_t _subM) {
+  auto indice_num_cpu = indiceNum.to({torch::kCPU});
+  auto indice_num_cpu_64 = indice_num_cpu.data_ptr();
+  int indice_num_len = indiceNum.numel();
+  int64_t indice_num[indice_num_len];
+  for (int i = 0; i < indice_num_len; ++i) {
+    indice_num[i] = (int64_t)(((int *)(indice_num_cpu_64))[i]);
+  }
+
+  // generate empty input_grad
+  torch::Tensor input_grad = at::zeros({features.size(0), features.size(1)},
+                                       features.options().dtype(at::kFloat));
+  torch::Tensor filters_grad;
+  if (filters.dim() == 4) {
+    int h = filters.size(0);
+    int w = filters.size(1);
+    int c = filters.size(2);
+    int n = filters.size(3);
+    filters_grad = at::zeros({h, w, c, n}, filters.options().dtype(at::kFloat));
+  } else if (filters.dim() == 5) {
+    int d = filters.size(0);
+    int h = filters.size(1);
+    int w = filters.size(2);
+    int c = filters.size(3);
+    int n = filters.size(4);
+    filters_grad =
+        at::zeros({d, h, w, c, n}, filters.options().dtype(at::kFloat));
+  }
+
+  auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      features, at::MemoryFormat::Contiguous);
+  auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      filters, at::MemoryFormat::Contiguous);
+  auto output_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      outGrad, at::MemoryFormat::Contiguous);
+  auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      indicePairs, at::MemoryFormat::Contiguous);
+  auto input_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      features, at::MemoryFormat::Contiguous);
+  auto filters_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      filters, at::MemoryFormat::Contiguous);
+
+  MluOpTensorDescriptor features_desc, output_grad_desc, filters_desc,
+      indice_pairs_desc, input_grad_desc, filters_grad_desc;
+  features_desc.set(features_contiguous);
+  filters_desc.set(filters_contiguous);
+  output_grad_desc.set(output_grad_contiguous);
+  indice_pairs_desc.set(indice_pairs_contiguous);
+  input_grad_desc.set(input_grad_contiguous);
+  filters_grad_desc.set(filters_grad_contiguous);
+
+  // need to set desc layout with mluOp functions
+  {
+    mluOpTensorLayout_t layout;
+    mluOpDataType_t dtype;
+    int dim;
+    int dims[8];
+
+    // features_desc
+    mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
+    mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                             dim, dims);
+
+    // filters_desc
+    mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
+    if (dim == 4) {
+      mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_HWCN, dtype,
+                               dim, dims);
+    } else {
+      mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                               dim, dims);
+    }
+
+    // output_grad_desc
+    mluOpGetTensorDescriptor(output_grad_desc.desc(), &layout, &dtype, &dim,
+                             dims);
+    mluOpSetTensorDescriptor(output_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                             dim, dims);
+
+    // indice_pairs_desc
+    mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
+                             dims);
+    mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
+                             dtype, dim, dims);
+
+    // input_grad_desc
+    mluOpGetTensorDescriptor(input_grad_desc.desc(), &layout, &dtype, &dim,
+                             dims);
+    mluOpSetTensorDescriptor(input_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                             dim, dims);
+  }
+
+  auto handle = mluOpGetCurrentHandle();
+  size_t data_workspace_size = 0;
+  mluOpGetIndiceConvolutionBackwardDataWorkspaceSize(
+      handle, output_grad_desc.desc(), filters_desc.desc(),
+      indice_pairs_desc.desc(), input_grad_desc.desc(), indice_num, _inverse,
+      &data_workspace_size);
+
+  size_t filters_workspace_size = 0;
+  mluOpGetIndiceConvolutionBackwardFilterWorkspaceSize(
+      handle, features_desc.desc(), output_grad_desc.desc(),
+      indice_pairs_desc.desc(), filters_grad_desc.desc(), indice_num, _inverse,
+      _subM, &filters_workspace_size);
+
+  auto indice_convbpdata_workspace =
+      at::empty(data_workspace_size, features.options().dtype(at::kByte));
+  auto indice_convbpfilter_workspace =
+      at::empty(filters_workspace_size, filters.options().dtype(at::kByte));
+
+  auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous);
+  auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
+  auto output_grad_impl = torch_mlu::getMluTensorImpl(output_grad_contiguous);
+  auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
+  auto indice_convbpdata_workspace_impl =
+      torch_mlu::getMluTensorImpl(indice_convbpdata_workspace);
+  auto indice_convbpfilter_workspace_impl =
+      torch_mlu::getMluTensorImpl(indice_convbpfilter_workspace);
+
+  auto features_ptr = features_impl->cnnlMalloc();
+  auto filters_ptr = filters_impl->cnnlMalloc();
+  auto output_grad_ptr = output_grad_impl->cnnlMalloc();
+  auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
+  auto indice_convbpdata_workspace_ptr =
+      indice_convbpdata_workspace_impl->cnnlMalloc();
+  auto indice_convbpfilter_workspace_ptr =
+      indice_convbpfilter_workspace_impl->cnnlMalloc();
+
+  // outputs
+  auto input_grad_impl = torch_mlu::getMluTensorImpl(input_grad);
+  auto input_grad_ptr = input_grad_impl->cnnlMalloc();
+  auto filters_grad_impl = torch_mlu::getMluTensorImpl(filters_grad);
+  auto filters_grad_ptr = filters_grad_impl->cnnlMalloc();
+
+  mluOpIndiceConvolutionBackwardData(
+      handle, output_grad_desc.desc(), output_grad_ptr, filters_desc.desc(),
+      filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
+      _inverse, _subM, indice_convbpdata_workspace_ptr, data_workspace_size,
+      input_grad_desc.desc(), input_grad_ptr);
+
+  mluOpIndiceConvolutionBackwardFilter(
+      handle, features_desc.desc(), features_ptr, output_grad_desc.desc(),
+      output_grad_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
+      _inverse, _subM, indice_convbpfilter_workspace_ptr,
+      filters_workspace_size, filters_grad_desc.desc(), filters_grad_ptr);
+
+  std::vector<torch::Tensor> result;
+  result.push_back(input_grad);
+  result.push_back(filters_grad);
+  return result;
+}
+
+torch::Tensor indice_conv_forward_mlu(torch::Tensor features,
+                                      torch::Tensor filters,
+                                      torch::Tensor indicePairs,
+                                      torch::Tensor indiceNum,
+                                      int64_t numActOut, int64_t _inverse,
+                                      int64_t _subM) {
+  return IndiceConvForwardMLUKernelLauncher(
+      features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM);
+}
+
+std::vector<torch::Tensor> indice_conv_backward_mlu(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
+    int64_t _subM) {
+  return IndiceConvBackwardMLUKernelLauncher(
+      features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM);
+}
+
+torch::Tensor indice_conv_forward_impl(torch::Tensor features,
+                                       torch::Tensor filters,
+                                       torch::Tensor indicePairs,
+                                       torch::Tensor indiceNum,
+                                       int64_t numActOut, int64_t _inverse,
+                                       int64_t _subM);
+
+std::vector<torch::Tensor> indice_conv_backward_impl(
+    torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
+    torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
+    int64_t _subM);
+
+REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MLU, indice_conv_forward_mlu);
+REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MLU, indice_conv_backward_mlu);
+
+template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<2>(
+    torch::Tensor indices, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
+
+template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<3>(
+    torch::Tensor indices, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
+
+template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<4>(
+    torch::Tensor indices, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
\ No newline at end of file
diff --git a/mmcv/ops/csrc/pytorch/spconv_ops.cpp b/mmcv/ops/csrc/pytorch/spconv_ops.cpp
index 09c8110ad8..5d1a62a065 100644
--- a/mmcv/ops/csrc/pytorch/spconv_ops.cpp
+++ b/mmcv/ops/csrc/pytorch/spconv_ops.cpp
@@ -35,6 +35,26 @@ std::vector<torch::Tensor> get_indice_pairs_forward_cuda(
       padding, dilation, outPadding, _subM, _transpose);
 };
 
+template <unsigned NDim>
+std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
+    torch::Tensor indices, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
+
+template <unsigned NDim>
+std::vector<torch::Tensor> get_indice_pairs_forward_mlu(
+    torch::Tensor indices, int64_t batchSize,
+    std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
+    std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
+    std::vector<int64_t> padding, std::vector<int64_t> dilation,
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
+  return GetIndicePairsForwardMLUKernelLauncher<NDim>(
+      indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,
+      padding, dilation, outPadding, _subM, _transpose);
+}
+
 template <unsigned NDim>
 std::vector<torch::Tensor> GetIndicePairsBackwardCUDAKernelLauncher(
     torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
@@ -71,6 +91,12 @@ std::vector<torch::Tensor> get_indice_pairs_forward(
       padding, dilation, outPadding, _subM, _transpose);
 #else
   AT_ERROR("get_indice_pairs is not compiled with GPU support");
+#endif
+#ifdef MMCV_WITH_MLU 
+  } else if (indices.device().type() == at::kMLU) {
+    return get_indice_pairs_forward_mlu<NDim>(
+        indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,
+        padding, dilation, outPadding, _subM, _transpose);
 #endif
   } else {
     AT_ERROR("get_indice_pairs is not implemented on CPU");
diff --git a/setup.py b/setup.py
index 5a7afe86f7..e1a54c1240 100644
--- a/setup.py
+++ b/setup.py
@@ -388,7 +388,7 @@ def get_mluops_version(file_path):
             glob.glob(
                 './mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)
         extra_objects = glob.glob(
-            './mlu-ops/bangc-ops/kernels/*/x86_64/*.o')
+            './mlu-ops/bangc-ops/kernels/kernel_wrapper/**/*.a')
         extension = MLUExtension
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
diff --git a/tests/test_ops/test_spconv.py b/tests/test_ops/test_spconv.py
index 098ff2189a..17ca5678ed 100644
--- a/tests/test_ops/test_spconv.py
+++ b/tests/test_ops/test_spconv.py
@@ -10,6 +10,8 @@
 if torch.__version__ == 'parrots':
     pytest.skip('not supported in parrots now', allow_module_level=True)
 
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
 
 def make_sparse_convmodule(in_channels,
                            out_channels,
@@ -76,21 +78,29 @@ def make_sparse_convmodule(in_channels,
     return layers
 
 
-@pytest.mark.skipif(
-    not torch.cuda.is_available(), reason='requires CUDA support')
-def test_make_sparse_convmodule():
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'cuda',
+        marks=pytest.mark.skipif(
+            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+    pytest.param(
+        'mlu',
+        marks=pytest.mark.skipif(
+            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+])
+def test_make_sparse_convmodule(device):
     torch.cuda.empty_cache()
     voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
                                    [6.8162713, -2.480431, -1.3616394, 0.36],
                                    [11.643568, -4.744306, -1.3580885, 0.16],
                                    [23.482342, 6.5036807, 0.5806964, 0.35]],
                                   dtype=torch.float32,
-                                  device='cuda')  # n, point_features
+                                  device=device)  # n, point_features
     coordinates = torch.tensor(
         [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
          [1, 35, 930, 469]],
         dtype=torch.int32,
-        device='cuda')  # n, 4(batch, ind_x, ind_y, ind_z)
+        device=device)  # n, 4(batch, ind_x, ind_y, ind_z)
 
     # test
     input_sp_tensor = SparseConvTensor(voxel_features, coordinates,
@@ -105,7 +115,7 @@ def test_make_sparse_convmodule():
         padding=0,
         conv_type='SubMConv3d',
         norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
-        order=('conv', 'norm', 'act')).cuda()
+        order=('conv', 'norm', 'act')).to(device)
     assert isinstance(sparse_block0[0], SubMConv3d)
     assert sparse_block0[0].in_channels == 4
     assert sparse_block0[0].out_channels == 16
@@ -118,16 +128,18 @@ def test_make_sparse_convmodule():
     out_features = sparse_block0(input_sp_tensor)
     assert out_features.features.shape == torch.Size([4, 16])
 
-    sparse_block1 = make_sparse_convmodule(
-        4,
-        16,
-        3,
-        'test1',
-        stride=1,
-        padding=0,
-        conv_type='SparseInverseConv3d',
-        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
-        order=('norm', 'act', 'conv')).cuda()
-    assert isinstance(sparse_block1[0], torch.nn.BatchNorm1d)
-    assert isinstance(sparse_block1[1], torch.nn.ReLU)
-    assert isinstance(sparse_block1[2], SparseInverseConv3d)
+    # device == 'mlu': inverse==1 is not supported yet
+    if device != 'mlu':
+        sparse_block1 = make_sparse_convmodule(
+            4,
+            16,
+            3,
+            'test1',
+            stride=1,
+            padding=0,
+            conv_type='SparseInverseConv3d',
+            norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
+            order=('norm', 'act', 'conv')).to(device)
+        assert isinstance(sparse_block1[2], SparseInverseConv3d)
+        assert isinstance(sparse_block1[0], torch.nn.BatchNorm1d)
+        assert isinstance(sparse_block1[1], torch.nn.ReLU)
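
This first patch registers the MLU launchers behind mmcv's generic dispatchers
(REGISTER_DEVICE_IMPL for indice_conv, the new device branch in
get_indice_pairs_forward), so existing Python code selects the MLU kernels from
the tensor device alone. A minimal smoke-test sketch of that path, mirroring
the updated tests/test_ops/test_spconv.py — illustrative only, assuming an
MLU-enabled mmcv build, and not part of the patch:

    import torch
    from mmcv.ops import SparseConvTensor, SubMConv3d

    device = 'mlu'  # assumption: a Cambricon MLU device is available
    features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
                             [6.8162713, -2.480431, -1.3616394, 0.36],
                             [11.643568, -4.744306, -1.3580885, 0.16],
                             [23.482342, 6.5036807, 0.5806964, 0.35]],
                            dtype=torch.float32, device=device,
                            requires_grad=True)
    coordinates = torch.tensor([[0, 12, 819, 131], [0, 16, 750, 136],
                                [1, 16, 705, 232], [1, 35, 930, 469]],
                               dtype=torch.int32, device=device)
    x = SparseConvTensor(features, coordinates, [41, 1600, 1408], 2)

    conv = SubMConv3d(4, 16, 3, stride=1, padding=0,
                      indice_key='subm0').to(device)
    out = conv(x)  # exercises GetIndicePairsForward + IndiceConvForward
    assert out.features.shape == torch.Size([4, 16])
    out.features.sum().backward()  # exercises IndiceConvBackward

The forward pass runs GetIndicePairsForwardMLUKernelLauncher and
IndiceConvForwardMLUKernelLauncher; the backward call reaches
IndiceConvBackwardMLUKernelLauncher through the registered MLU impl.
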
From ecf0d0cb339eddd2d6b85ff43b33b319456a978d Mon Sep 17 00:00:00 2001
From: duzekun
Date: Wed, 15 Mar 2023 14:35:50 +0800
Subject: [PATCH 2/6] [Feature] update cpp code style

---
 mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp | 118 ++++++++++--------
 mmcv/ops/csrc/pytorch/spconv_ops.cpp          |   2 +-
 2 files changed, 65 insertions(+), 55 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
index f32a9d6ffa..18b312025b 100644
--- a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
@@ -23,7 +23,6 @@ std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
     std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
     std::vector<int64_t> padding, std::vector<int64_t> dilation,
     std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
-
   // The following code is copied from
   // mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu to ensure the output is
   // available for network training. The outputs of this function have the
@@ -45,12 +44,13 @@ std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
   }
   torch::Tensor indicePairs = at::full({kernelVolume, 2, numAct}, -1,
                                        indices.options().dtype(at::kInt));
-  torch::Tensor indiceNum = at::zeros({kernelVolume},
-                                      indices.options().dtype(at::kInt));
-  int out_size = sub_m == 1 ?
-      numAct : std::min(numAct * kernelVolume, batch * outputVolume);
-  torch::Tensor out_indices = at::zeros({out_size, coorDim + 1},
-                                        indices.options().dtype(at::kInt));
+  torch::Tensor indiceNum =
+      at::zeros({kernelVolume}, indices.options().dtype(at::kInt));
+  int out_size = sub_m == 1
+                     ? numAct
+                     : std::min(numAct * kernelVolume, batch * outputVolume);
+  torch::Tensor out_indices =
+      at::zeros({out_size, coorDim + 1}, indices.options().dtype(at::kInt));
   auto indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
       indices, at::MemoryFormat::Contiguous);
@@ -67,56 +67,58 @@ std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
   std::vector<int> stride32;
   std::vector<int> dilation32;
   for (int i = 0; i < NDim; i++) {
-      input_space.push_back(spatialShape[i]);
-      filter_space.push_back(kernelSize[i]);
-      output_space.push_back(outSpatialShape[i]);
-      padding32.push_back(padding[i]);
-      stride32.push_back(stride[i]);
-      dilation32.push_back(dilation[i]);
+    input_space.push_back(spatialShape[i]);
+    filter_space.push_back(kernelSize[i]);
+    output_space.push_back(outSpatialShape[i]);
+    padding32.push_back(padding[i]);
+    stride32.push_back(stride[i]);
+    dilation32.push_back(dilation[i]);
   }
-  MluOpTensorDescriptor indices_desc, out_indices_desc,
-      indicePairs_desc, indiceNum_desc;
+  MluOpTensorDescriptor indices_desc, out_indices_desc, indicePairs_desc,
+      indiceNum_desc;
   indices_desc.set(indices_contiguous);
   indicePairs_desc.set(indicePairs_contiguous);
   indiceNum_desc.set(indiceNum_contiguous);
   out_indices_desc.set(out_indices_contiguous);
   {
-      mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
-      mluOpDataType_t dtype = MLUOP_DTYPE_INT32;
-      std::vector<int> dims;
-      dims = {numAct, coorDim + 1};
-      mluOpSetTensorDescriptor(indices_desc.desc(), layout, dtype,
-                               dims.size(), dims.data());
-      dims = {kernelVolume, 2, numAct};
-      mluOpSetTensorDescriptor(indicePairs_desc.desc(), layout, dtype,
-                               dims.size(), dims.data());
-      dims = {kernelVolume};
-      mluOpSetTensorDescriptor(indiceNum_desc.desc(), layout, dtype,
-                               dims.size(), dims.data());
-      dims = {out_size, coorDim + 1};
-      mluOpSetTensorDescriptor(out_indices_desc.desc(), layout, dtype,
-                               dims.size(), dims.data());
+    mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
+    mluOpDataType_t dtype = MLUOP_DTYPE_INT32;
+    std::vector<int> dims;
+    dims = {numAct, coorDim + 1};
+    mluOpSetTensorDescriptor(indices_desc.desc(), layout, dtype, dims.size(),
+                             dims.data());
+    dims = {kernelVolume, 2, numAct};
+    mluOpSetTensorDescriptor(indicePairs_desc.desc(), layout, dtype,
+                             dims.size(), dims.data());
+    dims = {kernelVolume};
+    mluOpSetTensorDescriptor(indiceNum_desc.desc(), layout, dtype, dims.size(),
+                             dims.data());
+    dims = {out_size, coorDim + 1};
+    mluOpSetTensorDescriptor(out_indices_desc.desc(), layout, dtype,
+                             dims.size(), dims.data());
   }
 
   mluOpSparseConvolutionDescriptor_t sparse_conv_desc;
   mluOpCreateSparseConvolutionDescriptor(&sparse_conv_desc);
-  mluOpSetSparseConvolutionDescriptor(sparse_conv_desc, NDim + 2, batch,
-      padding32.data(), stride32.data(), dilation32.data(), input_space.data(),
-      filter_space.data(), output_space.data(), sub_m, transpose, 0);
+  mluOpSetSparseConvolutionDescriptor(
+      sparse_conv_desc, NDim + 2, batch, padding32.data(), stride32.data(),
+      dilation32.data(), input_space.data(), filter_space.data(),
+      output_space.data(), sub_m, transpose, 0);
 
   auto handle = mluOpGetCurrentHandle();
   size_t workspace_size = 0;
-  mluOpGetIndicePairsWorkspaceSize(handle, sparse_conv_desc, indices_desc.desc(),
-      indicePairs_desc.desc(), out_indices_desc.desc(), indiceNum_desc.desc(),
-      &workspace_size);
-  auto indice_workspace_size = at::empty(workspace_size,
-                                         indices.options().dtype(at::kByte));
+  mluOpGetIndicePairsWorkspaceSize(
+      handle, sparse_conv_desc, indices_desc.desc(), indicePairs_desc.desc(),
+      out_indices_desc.desc(), indiceNum_desc.desc(), &workspace_size);
+  auto indice_workspace_size =
+      at::empty(workspace_size, indices.options().dtype(at::kByte));
 
   auto indices_impl = torch_mlu::getMluTensorImpl(indices_contiguous);
   auto out_indices_impl = torch_mlu::getMluTensorImpl(out_indices_contiguous);
   auto indicePairs_impl = torch_mlu::getMluTensorImpl(indicePairs_contiguous);
   auto indiceNum_impl = torch_mlu::getMluTensorImpl(indiceNum_contiguous);
-  auto indice_workspace_impl = torch_mlu::getMluTensorImpl(indice_workspace_size);
+  auto indice_workspace_impl =
+      torch_mlu::getMluTensorImpl(indice_workspace_size);
 
   auto indices_ptr = indices_impl->cnnlMalloc();
   auto out_indices_ptr = out_indices_impl->cnnlMalloc();
@@ -124,17 +126,18 @@ std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
   auto indiceNum_ptr = indiceNum_impl->cnnlMalloc();
   auto indice_workspace_ptr = indice_workspace_impl->cnnlMalloc();
 
-  mluOpGetIndicePairs(handle, sparse_conv_desc, indices_desc.desc(), indices_ptr,
-      indice_workspace_ptr, workspace_size, indicePairs_desc.desc(),
-      indicePairs_ptr, out_indices_desc.desc(), out_indices_ptr,
-      indiceNum_desc.desc(), indiceNum_ptr);
+  mluOpGetIndicePairs(handle, sparse_conv_desc, indices_desc.desc(),
+                      indices_ptr, indice_workspace_ptr, workspace_size,
+                      indicePairs_desc.desc(), indicePairs_ptr,
+                      out_indices_desc.desc(), out_indices_ptr,
+                      indiceNum_desc.desc(), indiceNum_ptr);
   int num_act_out = 0;
   mluOpGetSparseConvolutionNumActOut(sparse_conv_desc, &num_act_out);
   mluOpDestroySparseConvolutionDescriptor(sparse_conv_desc);
   if (!sub_m) {
-      return {out_indices.slice(0, 0, num_act_out), indicePairs, indiceNum};
+    return {out_indices.slice(0, 0, num_act_out), indicePairs, indiceNum};
   } else {
-      return {indices, indicePairs, indiceNum};
+    return {indices, indicePairs, indiceNum};
   }
 }
 
@@ -180,26 +183,32 @@ torch::Tensor IndiceConvForwardMLUKernelLauncher(
 
     // features_desc
     mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims);
-    mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims);
+    mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                             dim, dims);
 
     // filters_desc
     mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims);
-    mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims);
+    mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype,
+                             dim, dims);
 
     // indice_pairs_desc
-    mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim, dims);
-    mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims);
+    mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim,
+                             dims);
+    mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY,
+                             dtype, dim, dims);
 
     // output_desc
     mluOpGetTensorDescriptor(output_desc.desc(), &layout, &dtype, &dim, dims);
-    mluOpSetTensorDescriptor(output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, dims);
+    mluOpSetTensorDescriptor(output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim,
+                             dims);
   }
 
   auto handle = mluOpGetCurrentHandle();
   size_t workspace_size = 0;
   mluOpGetIndiceConvolutionForwardWorkspaceSize(
-      handle, features_desc.desc(), filters_desc.desc(), indice_pairs_desc.desc(),
-      output_desc.desc(), indice_num, numActOut, _inverse, _subM, &workspace_size);
+      handle, features_desc.desc(), filters_desc.desc(),
+      indice_pairs_desc.desc(), output_desc.desc(), indice_num, numActOut,
+      _inverse, _subM, &workspace_size);
 
   auto workspace =
       at::empty(workspace_size, features.options().dtype(at::kByte));
@@ -216,11 +225,12 @@ torch::Tensor IndiceConvForwardMLUKernelLauncher(
 
   // outputs
   auto output_impl = torch_mlu::getMluTensorImpl(output);
-   auto output_ptr = output_impl->cnnlMalloc();
+  auto output_ptr = output_impl->cnnlMalloc();
   mluOpIndiceConvolutionForward(
-      handle, features_desc.desc(), features_ptr, filters_desc.desc(), filters_ptr,
-      indice_pairs_desc.desc(), indice_pairs_ptr, indice_num, numActOut, _inverse, _subM,
-      workspace_ptr, workspace_size, output_desc.desc(), output_ptr);
+      handle, features_desc.desc(), features_ptr, filters_desc.desc(),
+      filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num,
+      numActOut, _inverse, _subM, workspace_ptr, workspace_size,
+      output_desc.desc(), output_ptr);
 
   return output;
 }
diff --git a/mmcv/ops/csrc/pytorch/spconv_ops.cpp b/mmcv/ops/csrc/pytorch/spconv_ops.cpp
index 5d1a62a065..723c6c7b90 100644
--- a/mmcv/ops/csrc/pytorch/spconv_ops.cpp
+++ b/mmcv/ops/csrc/pytorch/spconv_ops.cpp
@@ -92,7 +92,7 @@ std::vector<torch::Tensor> get_indice_pairs_forward(
 #else
   AT_ERROR("get_indice_pairs is not compiled with GPU support");
 #endif
-#ifdef MMCV_WITH_MLU 
+#ifdef MMCV_WITH_MLU
   } else if (indices.device().type() == at::kMLU) {
     return get_indice_pairs_forward_mlu<NDim>(
         indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,

From 6b3421b51921d943fb4155830fa65a0fad16447f Mon Sep 17 00:00:00 2001
From: duzekun
Date: Wed, 15 Mar 2023 14:42:14 +0800
Subject: [PATCH 3/6] end-of-file

---
 mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
index 18b312025b..c025a84d21 100644
--- a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
@@ -442,4 +442,4 @@ template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<4>(
     std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
     std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
     std::vector<int64_t> padding, std::vector<int64_t> dilation,
-    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
\ No newline at end of file
+    std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
From b4af962fe66eb4f6d81a8c121e1f55c8fe9d5a37 Mon Sep 17 00:00:00 2001
From: budefei
Date: Fri, 17 Mar 2023 12:53:08 +0800
Subject: [PATCH 4/6] delete libext.a

---
 mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h | 2 +-
 setup.py                                      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
index 678dc52029..38805c0dec 100644
--- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
+++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
@@ -19,7 +19,7 @@
 
 #define MLUOP_MAJOR 0
 #define MLUOP_MINOR 5
-#define MLUOP_PATCHLEVEL 0
+#define MLUOP_PATCHLEVEL 302
 
 mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
 mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
diff --git a/setup.py b/setup.py
index e1a54c1240..176db6b870 100644
--- a/setup.py
+++ b/setup.py
@@ -388,7 +388,7 @@ def get_mluops_version(file_path):
             glob.glob(
                 './mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)
         extra_objects = glob.glob(
-            './mlu-ops/bangc-ops/kernels/kernel_wrapper/**/*.a')
+            './mlu-ops/bangc-ops/kernels/kernel_wrapper/*.o')
         extension = MLUExtension
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))

From ee88e2e450a60626e44f236b9e07fb06bd259eae Mon Sep 17 00:00:00 2001
From: budefei
Date: Fri, 17 Mar 2023 14:57:48 +0800
Subject: [PATCH 5/6] code style

---
 mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
index c025a84d21..165aae1715 100644
--- a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
@@ -9,13 +9,14 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
+#include <torch/script.h>
+
+#include <vector>
+
 #include "mlu_common_helper.h"
 #include "pytorch_device_registry.hpp"
 #include "pytorch_mlu_helper.hpp"
 
-#include <torch/script.h>
-#include <vector>
-
 template <unsigned NDim>
 std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
     torch::Tensor indices, int64_t batchSize,

From e5384ca630cca6d905e0d727bca54c3daf0714b6 Mon Sep 17 00:00:00 2001
From: duzekun
Date: Mon, 20 Mar 2023 19:15:14 +0800
Subject: [PATCH 6/6] update ops.md

---
 docs/en/understand_mmcv/ops.md    | 2 +-
 docs/zh_cn/understand_mmcv/ops.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md
index 6c5f760bbe..95cf94de5b 100644
--- a/docs/en/understand_mmcv/ops.md
+++ b/docs/en/understand_mmcv/ops.md
@@ -52,7 +52,7 @@ We implement common ops used in detection, segmentation, etc.
 | SigmoidFocalLoss       |     | √   | √   |     | √   |
 | SoftmaxFocalLoss       |     | √   |     |     | √   |
 | SoftNMS                |     | √   |     |     |     |
-| Sparse Convolution     |     | √   |     |     |     |
+| Sparse Convolution     |     | √   | √   |     |     |
 | Synchronized BatchNorm |     | √   |     |     |     |
 | ThreeInterpolate       |     | √   |     |     |     |
 | ThreeNN                |     | √   | √   |     |     |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index 0fd0873a32..b4ace828d8 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -52,7 +52,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | SigmoidFocalLoss       |     | √   | √   |     | √   |
 | SoftmaxFocalLoss       |     | √   |     |     | √   |
 | SoftNMS                |     | √   |     |     |     |
-| Sparse Convolution     |     | √   |     |     |     |
+| Sparse Convolution     |     | √   | √   |     |     |
 | Synchronized BatchNorm |     | √   |     |     |     |
 | ThreeInterpolate       |     | √   |     |     |     |
 | ThreeNN                |     | √   | √   |     |     |
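
The tables above now mark Sparse Convolution as supported on MLU, but the
series leaves SparseInverseConv3d (inverse==1) uncovered — the updated test
skips that branch on 'mlu'. Until that lands, device-dependent model code has
to guard for it. A hypothetical helper sketch (make_conv3d_block and the 'k0'
indice_key are illustrative names, not mmcv API):

    from mmcv.ops import SparseInverseConv3d, SubMConv3d

    def make_conv3d_block(in_channels, out_channels, kernel_size, device):
        # Sketch only: the MLU kernels in this series do not support
        # inverse==1 yet, so fall back to a submanifold conv on 'mlu'.
        if device == 'mlu':
            return SubMConv3d(in_channels, out_channels, kernel_size,
                              indice_key='k0').to(device)
        return SparseInverseConv3d(in_channels, out_channels, kernel_size,
                                   indice_key='k0').to(device)

Note that an inverse conv must reuse the indice_key of the forward conv it
inverts; 'k0' here stands in for that shared key.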