Skip to content

Commit

Permalink
tmp: mmcv add sparse_conv related code
Browse files Browse the repository at this point in the history
  • Loading branch information
duzekunKTH committed Feb 10, 2023
1 parent 5a61e53 commit a09c2b6
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 8 deletions.
198 changes: 198 additions & 0 deletions mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

#include <torch/script.h>
#include <vector>

template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
std::cout << "GetIndicePairsForwardMLUKernelLauncher start." << std::endl;

// The following code is copied from mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu
// to ensure the output is available for network train.
// The outputs of this function have correct shape but wrong value.
auto numAct = indices.size(0);
auto kernelVolume = kernelSize[0];
for (int i = 1; i < kernelSize.size(); ++i) {
kernelVolume *= kernelSize[i];
}

auto outputVolume = outSpatialShape[0];
for (int i = 1; i < outSpatialShape.size(); ++i) {
outputVolume *= outSpatialShape[i];
}
torch::Tensor indicePairs =
at::full({kernelVolume, 2, numAct}, -1,
indices.options().dtype(at::kInt));
torch::Tensor indiceNum = at::zeros(
{kernelVolume}, indices.options().dtype(at::kInt));

std::cout << "GetIndicePairsForwardMLUKernelLauncher finish." << std::endl;
return {indices, indicePairs, indiceNum}; // differ from cuda code
}

torch::Tensor IndiceConvForwardMLUKernelLauncher(
torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
int64_t _subM) {
std::cout << "IndiceConvForwardMLUKernelLauncher start." << std::endl;
int C = filters.dim() == 4 ?
filters.size(3) : filters.size(4);
torch::Tensor output = at::zeros({numActOut, C}, features.options().dtype(at::kFloat));

std::cout << "IndiceConvForwardMLUKernelLauncher finish." << std::endl;
return output;
}

std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM) {
std::cout << "IndiceConvBackwardMLUKernelLauncher start." << std::endl;
/*
auto indice_num_cpu = indiceNum.to({torch::kCPU});
auto indice_num_cpu_64 = indice_num_cpu.data_ptr<int>();
int indice_num_len = indiceNum.numel();
int64_t indice_num[indice_num_len];
for (int i = 0; i < indice_num_len; ++i) {
// indice_num[i] = ((int64_t *)(indice_num_cpu_64.unsafeGetTensorImpl()->data()))[i];
indice_num[i] = (int64_t)(((int *)(indice_num_cpu_64))[i]);
}
auto input_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
features, features.suggest_memory_format());
auto output_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
outGrad, outGrad.suggest_memory_format());
auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
filters, filters.suggest_memory_format());
auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
indicePairs, indicePairs.suggest_memory_format());
MluOpTensorDescriptor output_grad_desc, filters_desc, indice_pairs_desc, input_grad_desc;
input_grad_desc.set(input_grad_contiguous);
output_grad_desc.set(output_grad_contiguous);
filters_desc.set(filters_contiguous);
indice_pairs_desc.set(indice_pairs_contiguous);
auto handle = mluOpGetCurrentHandle();
size_t workspace_size = 0;
mluOpGetIndiceConvolutionBackwardDataWorkspaceSize(
handle, output_grad_desc.desc(), filters_desc.desc(),
indice_pairs_desc.desc(), input_grad_desc.desc(),
indice_num, _inverse, &workspace_size);
printf("mluOpGetIndiceConvolutionBackwardDataWorkspaceSize %ld\n", workspace_size);
*/

// generate empty input_grad
torch::Tensor input_grad = at::zeros({features.size(0), features.size(1)}, features.options().dtype(at::kFloat));
torch::Tensor filters_grad;
if (filters.dim() == 4) {
int h = filters.size(0);
int w = filters.size(1);
int c = filters.size(2);
int n = filters.size(3);
filters_grad = at::zeros({h, w, c, n}, filters.options().dtype(at::kFloat));
} else if (filters.dim() == 5) {
int d = filters.size(0);
int h = filters.size(1);
int w = filters.size(2);
int c = filters.size(3);
int n = filters.size(4);
filters_grad = at::zeros({d, h, w, c, n}, filters.options().dtype(at::kFloat));
}
/*
auto indice_convbpdata_workspace = at::empty(workspace_size, features.options().dtype(at::kByte));
auto output_grad_impl = torch_mlu::getMluTensorImpl(output_grad_contiguous);
auto input_grad_impl = torch_mlu::getMluTensorImpl(input_grad_contiguous);
auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous);
auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous);
auto indice_convbpdata_workspace_impl = torch_mlu::getMluTensorImpl(indice_convbpdata_workspace);
auto output_grad_ptr = output_grad_impl->cnnlMalloc();
auto input_grad_ptr = input_grad_impl->cnnlMalloc();
auto filters_ptr = filters_impl->cnnlMalloc();
auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc();
auto indice_convbpdata_workspace_ptr = indice_convbpdata_workspace_impl->cnnlMalloc();
mluOpIndiceConvolutionBackwardData(
handle, output_grad_desc.desc(), output_grad_ptr, filters_desc.desc(), filters_ptr,
indice_pairs_desc.desc(), indice_pairs_ptr, indice_num, _inverse, _subM,
indice_convbpdata_workspace_ptr, workspace_size, input_grad_desc.desc(), input_grad_ptr);
*/
std::vector<torch::Tensor> result;
result.push_back(input_grad);
result.push_back(filters_grad);
std::cout << "IndiceConvBackwardMLUKernelLauncher finish." << std::endl;
return result;
}

torch::Tensor indice_conv_forward_mlu(
torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
int64_t _subM) {
return IndiceConvForwardMLUKernelLauncher(
features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM);
}

std::vector<torch::Tensor> indice_conv_backward_mlu(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM) {
return IndiceConvBackwardMLUKernelLauncher(
features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM);
}

torch::Tensor indice_conv_forward_impl(torch::Tensor features,
torch::Tensor filters,
torch::Tensor indicePairs,
torch::Tensor indiceNum,
int64_t numActOut, int64_t _inverse,
int64_t _subM);

std::vector<torch::Tensor> indice_conv_backward_impl(
torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad,
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM);

REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MLU, indice_conv_forward_mlu);
REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MLU, indice_conv_backward_mlu);


template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<2>(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<3>(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

template std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher<4>(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);
28 changes: 28 additions & 0 deletions mmcv/ops/csrc/pytorch/spconv_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ std::vector<torch::Tensor> get_indice_pairs_forward_cuda(
padding, dilation, outPadding, _subM, _transpose);
};

template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsForwardMLUKernelLauncher(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose);

template <unsigned NDim>
std::vector<torch::Tensor> get_indice_pairs_forward_mlu(
torch::Tensor indices, int64_t batchSize,
std::vector<int64_t> outSpatialShape, std::vector<int64_t> spatialShape,
std::vector<int64_t> kernelSize, std::vector<int64_t> stride,
std::vector<int64_t> padding, std::vector<int64_t> dilation,
std::vector<int64_t> outPadding, int64_t _subM, int64_t _transpose) {
return GetIndicePairsForwardMLUKernelLauncher<NDim>(
indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,
padding, dilation, outPadding, _subM, _transpose);
}

template <unsigned NDim>
std::vector<torch::Tensor> GetIndicePairsBackwardCUDAKernelLauncher(
torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize,
Expand Down Expand Up @@ -71,6 +91,14 @@ std::vector<torch::Tensor> get_indice_pairs_forward(
padding, dilation, outPadding, _subM, _transpose);
#else
AT_ERROR("get_indice_pairs is not compiled with GPU support");
#endif
} else if (indices.device().type() == at::kMLU) {
#ifdef MMCV_WITH_MLU
return get_indice_pairs_forward_mlu<NDim>(
indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride,
padding, dilation, outPadding, _subM, _transpose);
#else
AT_ERROR("get_indice_pairs is not compiled with MLU support");
#endif
} else {
AT_ERROR("get_indice_pairs is not implemented on CPU");
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def get_mluops_version(file_path):
glob.glob(
'./mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)
extra_objects = glob.glob(
'./mlu-ops/bangc-ops/kernels/*/x86_64/*.o')
'./mlu-ops/bangc-ops/kernels/kernel_wrapper/**/*.a')
extension = MLUExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
Expand Down
56 changes: 49 additions & 7 deletions tests/test_ops/test_spconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
if torch.__version__ == 'parrots':
pytest.skip('not supported in parrots now', allow_module_level=True)

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

def make_sparse_convmodule(in_channels,
out_channels,
Expand Down Expand Up @@ -76,21 +77,29 @@ def make_sparse_convmodule(in_channels,
return layers


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_make_sparse_convmodule():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_make_sparse_convmodule(device):
torch.cuda.empty_cache()
voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
[6.8162713, -2.480431, -1.3616394, 0.36],
[11.643568, -4.744306, -1.3580885, 0.16],
[23.482342, 6.5036807, 0.5806964, 0.35]],
dtype=torch.float32,
device='cuda') # n, point_features
device=device) # n, point_features
coordinates = torch.tensor(
[[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
[1, 35, 930, 469]],
dtype=torch.int32,
device='cuda') # n, 4(batch, ind_x, ind_y, ind_z)
device=device) # n, 4(batch, ind_x, ind_y, ind_z)

# test
input_sp_tensor = SparseConvTensor(voxel_features, coordinates,
Expand All @@ -105,7 +114,7 @@ def test_make_sparse_convmodule():
padding=0,
conv_type='SubMConv3d',
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
order=('conv', 'norm', 'act')).cuda()
order=('conv', 'norm', 'act')).to(device)
assert isinstance(sparse_block0[0], SubMConv3d)
assert sparse_block0[0].in_channels == 4
assert sparse_block0[0].out_channels == 16
Expand All @@ -127,7 +136,40 @@ def test_make_sparse_convmodule():
padding=0,
conv_type='SparseInverseConv3d',
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
order=('norm', 'act', 'conv')).cuda()
order=('norm', 'act', 'conv')).to(device)
assert isinstance(sparse_block1[0], torch.nn.BatchNorm1d)
assert isinstance(sparse_block1[1], torch.nn.ReLU)
assert isinstance(sparse_block1[2], SparseInverseConv3d)

# test_make_sparse_convmodule('mlu')

def test_indice_conv_bp():
import numpy as np
from mmcv.utils import ext_loader
ext_module = ext_loader.load_ext('_ext',['indice_conv_backward'])

indice_pairs_num = [[[0,-1],
[0,-1]],
[[0,1],
[0,1]],
[[0,-1],
[0,-1]]]
feature = torch.tensor(np.ones((10,10))).mlu().float()
filters = torch.tensor(np.ones((3,1,1,10,10))).mlu().float()
outgrad = torch.tensor(np.ones((10,10))).mlu().float()
indice_pairs = torch.tensor(indice_pairs_num).mlu().int()
indice_num = torch.tensor([1,2,1]).mlu().int()
inverse = 0
sub_m = 1

print(indice_num)
ingrad, filter_grad = ext_module.indice_conv_backward(
feature,
filters,
outgrad,
indice_pairs,
indice_num,
int(inverse),
int(sub_m))

test_indice_conv_bp()

0 comments on commit a09c2b6

Please sign in to comment.