Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add spconv ops from mmdet3d #1581

Merged
merged 37 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5115f07
add ops (spconv) of mmdet3d
DCNSW Nov 2, 2021
a4d2197
fix typo
DCNSW Nov 5, 2021
9c2c06a
refactor code
DCNSW Nov 20, 2021
55a0a15
resolve comments in #1452
WU-Wenhao Dec 13, 2021
fa91f84
resolve conflict
WU-Wenhao Dec 13, 2021
924058e
fix compile error
Dec 14, 2021
d4c31fa
Merge branch 'master' of github.com:open-mmlab/mmcv into add-mmdet3d-…
WU-Wenhao Dec 20, 2021
701b052
Merge branch 'add-mmdet3d-ops-spconv' of github.com:wHao-Wu/mmcv into…
WU-Wenhao Dec 20, 2021
05aac3e
fix bugs
WU-Wenhao Dec 20, 2021
1df25fc
fix bug
WU-Wenhao Dec 20, 2021
b431e45
transform from 'types.h' to 'extension.h'
WU-Wenhao Dec 21, 2021
04b2ae2
fix bug
WU-Wenhao Dec 21, 2021
a26e89e
transform from 'types.h' to 'extension.h' in parrots
WU-Wenhao Dec 21, 2021
da74a1d
add extension.h in pybind.cpp
WU-Wenhao Dec 21, 2021
e0ffb74
add unittest
WU-Wenhao Dec 21, 2021
d054696
Recover code
WU-Wenhao Dec 21, 2021
638be36
(1) Remove prettyprint.h
WU-Wenhao Dec 27, 2021
28059cf
(1) rename from `cu.h` to `cuh`
WU-Wenhao Dec 28, 2021
5efa302
reorganize files
WU-Wenhao Dec 29, 2021
65b0208
Add docstring for sparse_functional.py
WU-Wenhao Jan 5, 2022
12e32a6
resolve conflict
WU-Wenhao Jan 5, 2022
02b335d
use dispatcher
WU-Wenhao Jan 12, 2022
53c4a7a
remove template
Jan 12, 2022
3cec16d
use dispatch in cuda ops
WU-Wenhao Jan 12, 2022
b8e43a0
resolve Segmentation fault
WU-Wenhao Jan 14, 2022
0b8ae6b
resolve conflict
WU-Wenhao Jan 14, 2022
60a7ea9
remove useless files
WU-Wenhao Jan 14, 2022
49cc59e
fix lint
WU-Wenhao Jan 19, 2022
7406a5d
fix lint
WU-Wenhao Jan 19, 2022
4766b20
fix lint
WU-Wenhao Jan 19, 2022
ec416f9
fix unittest in test_build_layers.py
WU-Wenhao Jan 28, 2022
4c3b0dd
add tensorview into include_dirs when compiling
WU-Wenhao Feb 15, 2022
bfb1676
recover all deleted files
WU-Wenhao Feb 17, 2022
d56f750
fix lint and comments
WU-Wenhao Feb 17, 2022
11dc166
recover setup.py
WU-Wenhao Feb 17, 2022
037f4d2
replace tv::GPU as tv::TorchGPU & support device guard
WU-Wenhao Feb 18, 2022
03e4532
fix lint
WU-Wenhao Feb 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- SigmoidFocalLoss
- SoftmaxFocalLoss
- SoftNMS
- Sparse Convolution
- Synchronized BatchNorm
- Voxelization
- ThreeInterpolate
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- SigmoidFocalLoss
- SoftmaxFocalLoss
- SoftNMS
- Sparse Convolution
- Synchronized BatchNorm
- Voxelization
- ThreeInterpolate
Expand Down
10 changes: 10 additions & 0 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@
from .rotated_feature_align import rotated_feature_align
from .saconv import SAConv2d
from .scatter_points import DynamicScatter, dynamic_scatter
from .sparse_conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SubMConv2d, SubMConv3d)
from .sparse_modules import SparseModule, SparseSequential
from .sparse_pool import SparseMaxPool2d, SparseMaxPool3d
from .sparse_structure import SparseConvTensor, scatter_nd
from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
from .three_nn import three_nn
Expand Down Expand Up @@ -84,6 +90,10 @@
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'SparseConv2d', 'SparseConv3d', 'SparseConvTranspose2d',
'SparseConvTranspose3d', 'SparseInverseConv2d', 'SparseInverseConv3d',
'SubMConv2d', 'SubMConv3d', 'SparseModule', 'SparseSequential',
'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter',
'convex_iou', 'convex_giou'
Expand Down
236 changes: 236 additions & 0 deletions mmcv/ops/csrc/common/cuda/spconv/indice.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef INDICE_CU_H_
#define INDICE_CU_H_
#include <utils/spconv/spconv/geometry.h>
#include <utils/spconv/tensorview/tensorview.h>

#include <utils/spconv/tensorview/helper_kernel.cuh>

template <typename Index, typename IndexGrid, unsigned NDim,
int KernelMaxVolume = 256>
__global__ void prepareIndicePairsKernel(
tv::TensorView<const Index> indicesIn, tv::TensorView<Index> indicesOut,
tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indiceNum, tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape) {
auto numActIn = indicesIn.dim(0);
Index spatialVolume = 1;
#pragma unroll
for (int i = 0; i < NDim; ++i) {
spatialVolume *= outSpatialShape[i];
}
Index kernelVolume = 1;
#pragma unroll
for (int i = 0; i < NDim; ++i) {
kernelVolume *= kernelSize[i];
}
Index numValidPoints = 0;
Index validPoints[KernelMaxVolume * (NDim + 1)];
Index *pointPtr = nullptr;
auto indicePairsDim2 = indicePairs.dim(2);
Index index;
for (int ix : tv::KernelLoopX<int>(numActIn)) {
numValidPoints = getValidOutPos<Index, NDim>(
indicesIn.data() + ix * (NDim + 1) + 1, kernelSize.data(),
stride.data(), padding.data(), dilation.data(), outSpatialShape.data(),
validPoints);
for (Index i = 0; i < numValidPoints; ++i) {
pointPtr = validPoints + i * (NDim + 1);
auto offset = pointPtr[NDim];
auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(offset, 0, oldNum) = ix;
index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) +
spatialVolume * indicesIn(ix, 0);
indicePairs(offset, 1, oldNum) = index;
indicePairUnique[offset * indicePairsDim2 + oldNum] = index;
}
}
}

template <typename Index, typename IndexGrid, unsigned NDim,
int KernelMaxVolume = 256>
__global__ void prepareDeConvIndicePairsKernel(
tv::TensorView<const Index> indicesIn, tv::TensorView<Index> indicesOut,
tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indiceNum, tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape) {
auto numActIn = indicesIn.dim(0);
Index spatialVolume = 1;
#pragma unroll
for (int i = 0; i < NDim; ++i) {
spatialVolume *= outSpatialShape[i];
}
Index kernelVolume = 1;
#pragma unroll
for (int i = 0; i < NDim; ++i) {
kernelVolume *= kernelSize[i];
}
Index numValidPoints = 0;
Index validPoints[KernelMaxVolume * (NDim + 1)];
Index *pointPtr = nullptr;
auto indicePairsDim2 = indicePairs.dim(2);
Index index;
for (int ix : tv::KernelLoopX<int>(numActIn)) {
numValidPoints = getValidOutPosTranspose<Index, NDim>(
indicesIn.data() + ix * (NDim + 1) + 1, kernelSize.data(),
stride.data(), padding.data(), dilation.data(), outSpatialShape.data(),
validPoints);
for (Index i = 0; i < numValidPoints; ++i) {
pointPtr = validPoints + i * (NDim + 1);
auto offset = pointPtr[NDim];
auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(offset, 0, oldNum) = ix;
index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) +
spatialVolume * indicesIn(ix, 0);
indicePairs(offset, 1, oldNum) = index;
indicePairUnique[offset * indicePairsDim2 + oldNum] = index;
}
}
}

template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void assignGridAndIndiceOutKernel(
tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut,
int numAct, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> outSpatialShape, int batchSize) {
Index index;
auto indicesOutPtr = indicesOut.data();
for (int ix : tv::KernelLoopX<int>(numAct)) {
index = indicePairUnique[ix];
gridsOut[index] = ix;
index = tv::rowArrayIdxInv<Index, NDim>(
index, indicesOutPtr + ix * (NDim + 1) + 1, outSpatialShape.data());
indicesOut[ix * (NDim + 1)] = index % batchSize;
}
}

template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void assignIndicePairsKernel(
tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut,
int numActIn, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> outSpatialShape) {
Index index;
int kernelVolume = indicePairs.dim(0);
for (int ix : tv::KernelLoopX<int>(numActIn)) {
for (int i = 0; i < kernelVolume; ++i) {
index = indicePairs(i, 1, ix);
if (index > -1) {
indicePairs(i, 1, ix) = gridsOut[index];
}
}
}
}

template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void prepareSubMGridKernel(
tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
const tv::SimpleVector<Index, NDim> outSpatialShape) {
auto numActIn = indicesIn.dim(0);
Index spatialVolume = 1;
#pragma unroll
for (int i = 0; i < NDim; ++i) {
spatialVolume *= outSpatialShape[i];
}
Index index = 0;
for (int ix : tv::KernelLoopX<int>(numActIn)) {
index = tv::rowArrayIdx<Index, NDim>(indicesIn.data() + ix * (NDim + 1) + 1,
outSpatialShape.data()) +
spatialVolume * indicesIn(ix, 0);
gridsOut[index] = ix;
}
}

template <typename Index, typename IndexGrid, unsigned NDim,
int KernelMaxVolume = 256>
__global__ void getSubMIndicePairsKernel(
tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum,
const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape) {
auto numActIn = indicesIn.dim(0);
Index spatialVolume = 1;
#pragma unroll
for (int i = 0; i < NDim; ++i) {
spatialVolume *= outSpatialShape[i];
}
Index numValidPoints = 0;
Index validPoints[KernelMaxVolume * (NDim + 1)];
Index *pointPtr = nullptr;
Index index = 0;
for (int ix : tv::KernelLoopX<int>(numActIn)) {
numValidPoints = getValidOutPos<Index, NDim>(
indicesIn.data() + ix * (NDim + 1) + 1, kernelSize.data(),
stride.data(), padding.data(), dilation.data(), outSpatialShape.data(),
validPoints);
for (int i = 0; i < numValidPoints; ++i) {
pointPtr = validPoints + i * (NDim + 1);
auto offset = pointPtr[NDim];
index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) +
spatialVolume * indicesIn(ix, 0);
if (gridsOut[index] > -1) {
auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(offset, 1, oldNum) = gridsOut[index];
indicePairs(offset, 0, oldNum) = ix;
}
}
}
}

template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void resetGridKernel(const Index *indicePairUnique,
tv::TensorView<IndexGrid> gridsOut,
int numAct) {
for (int ix : tv::KernelLoopX<int>(numAct)) {
gridsOut[indicePairUnique[ix]] = -1;
}
}

template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void resetGridSubMKernel(
const Index *indices, tv::TensorView<IndexGrid> gridsOut,
const tv::SimpleVector<Index, NDim> outSpatialShape, int numAct) {
int outSpatialShapeReg[NDim];
for (int i = 0; i < NDim; ++i) {
outSpatialShapeReg[i] = outSpatialShape[i];
}
Index spatialVolume = 1;
auto indsPtr = indices;
#pragma unroll
for (int i = 0; i < NDim; ++i) {
spatialVolume *= outSpatialShape[i];
}
Index index;
for (int ix : tv::KernelLoopX<int>(numAct)) {
indsPtr = indices + ix * (NDim + 1);
index = tv::rowArrayIdx<Index, NDim>(indsPtr + 1, outSpatialShapeReg);
gridsOut[index + spatialVolume * indsPtr[0]] = -1;
}
}

#endif
Loading