Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add interpolate ops from mmdet3d #1355

Merged
merged 9 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- SoftmaxFocalLoss
- SoftNMS
- Synchronized BatchNorm
- ThreeInterpolate
- ThreeNN
- Weight standardization
- Correlation
2 changes: 2 additions & 0 deletions docs_zh_CN/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- SoftmaxFocalLoss
- SoftNMS
- Synchronized BatchNorm
- ThreeInterpolate
- ThreeNN
- Weight standardization
- Correlation
7 changes: 5 additions & 2 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from .roi_pool import RoIPool, roi_pool
from .saconv import SAConv2d
from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
from .three_nn import three_nn
from .tin_shift import TINShift, tin_shift
from .upfirdn2d import upfirdn2d

Expand All @@ -59,7 +61,8 @@
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'contour_expand', 'MultiScaleDeformableAttention',
'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample',
'pixel_group', 'contour_expand', 'three_nn', 'three_interpolate',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
]
61 changes: 61 additions & 0 deletions mmcv/ops/csrc/common/cuda/three_interpolate_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef THREE_INTERPOLATE_CUDA_KERNEL_CUH
#define THREE_INTERPOLATE_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__global__ void three_interpolate_forward_cuda_kernel(
    int b, int c, int m, int n, const T *points, const int *__restrict__ idx,
    const T *weight, T *out) {
  // Inputs:
  //   points: (B, C, M)
  //   idx:    (B, N, 3)  positions along the M axis of `points`
  //   weight: (B, N, 3)  coefficient applied to the matching `idx` entry
  // Output:
  //   out:    (B, C, N)
  //
  // Thread mapping: blockIdx.z selects the batch, blockIdx.y the channel and
  // the flattened x coordinate the output point. Each thread writes exactly
  // one element of `out` as a weighted sum of three elements of `points`.

  const int batch = blockIdx.z;
  const int channel = blockIdx.y;
  const int point = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads that fall outside the (B, C, N) output volume have nothing to do.
  if (batch >= b || channel >= c || point >= n) return;

  // idx and weight share the same (B, N, 3) layout, hence the same offset.
  const int triple_offset = batch * n * 3 + point * 3;
  const T *coeff = weight + triple_offset;
  const int *neighbor = idx + triple_offset;

  // Row of `points` / `out` belonging to this (batch, channel) pair.
  const T *src_row = points + (batch * c * m + channel * m);
  T *dst_row = out + (batch * c * n + channel * n);

  dst_row[point] = coeff[0] * src_row[neighbor[0]] +
                   coeff[1] * src_row[neighbor[1]] +
                   coeff[2] * src_row[neighbor[2]];
}

template <typename T>
__global__ void three_interpolate_backward_cuda_kernel(
    int b, int c, int n, int m, const T *grad_out, const int *__restrict__ idx,
    const T *weight, T *grad_points) {
  // Inputs:
  //   grad_out: (B, C, N)
  //   idx:      (B, N, 3)  positions along the M axis of `grad_points`
  //   weight:   (B, N, 3)
  // Output:
  //   grad_points: (B, C, M), accumulated into with atomicAdd (this kernel
  //                never overwrites or clears it)
  //
  // Thread mapping: blockIdx.z selects the batch, blockIdx.y the channel and
  // the flattened x coordinate the point along N.

  const int batch = blockIdx.z;
  const int channel = blockIdx.y;
  const int point = blockIdx.x * blockDim.x + threadIdx.x;

  if (batch >= b || channel >= c || point >= n) return;

  // idx and weight share the same (B, N, 3) layout, hence the same offset.
  const int triple_offset = batch * n * 3 + point * 3;
  const T *coeff = weight + triple_offset;
  const int *neighbor = idx + triple_offset;

  // Incoming gradient for this (batch, channel, point) element.
  const T upstream = grad_out[batch * c * n + channel * n + point];

  // Row of `grad_points` belonging to this (batch, channel) pair.
  T *dst_row = grad_points + (batch * c * m + channel * m);

  // Several threads may target the same destination element, so the three
  // contributions are scattered atomically.
  for (int j = 0; j < 3; ++j) {
    atomicAdd(dst_row + neighbor[j], upstream * coeff[j]);
  }
}

#endif // THREE_INTERPOLATE_CUDA_KERNEL_CUH
66 changes: 66 additions & 0 deletions mmcv/ops/csrc/common/cuda/three_nn_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef THREE_NN_CUDA_KERNEL_CUH
#define THREE_NN_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__global__ void three_nn_forward_cuda_kernel(int b, int n, int m,
                                             const T *unknown, const T *known,
                                             T *dist2, int *__restrict__ idx) {
  // Inputs:
  //   unknown: (B, N, 3)
  //   known:   (B, M, 3)
  // Outputs:
  //   dist2: (B, N, 3)  the three smallest squared distances, ascending
  //   idx:   (B, N, 3)  positions in `known` that produced those distances
  //
  // Thread mapping: blockIdx.y selects the batch and the flattened x
  // coordinate selects one point of `unknown`. Each thread scans all M points
  // of `known` sequentially.

  const int batch = blockIdx.y;
  const int query = blockIdx.x * blockDim.x + threadIdx.x;
  if (batch >= b || query >= n) return;

  // unknown, dist2 and idx all use the (B, N, 3) layout.
  const int query_offset = batch * n * 3 + query * 3;
  const T *q = unknown + query_offset;
  T *out_dist = dist2 + query_offset;
  int *out_idx = idx + query_offset;

  // First candidate of this batch; advanced by three values per iteration.
  const T *cand = known + batch * m * 3;

  const T ux = q[0];
  const T uy = q[1];
  const T uz = q[2];

  // Running top-3, kept sorted: nearest <= second <= third.
  // Slots that are never filled keep the 1e40 sentinel and index 0.
  double nearest = 1e40, second = 1e40, third = 1e40;
  int nearest_i = 0, second_i = 0, third_i = 0;

  for (int k = 0; k < m; ++k, cand += 3) {
    const T x = cand[0];
    const T y = cand[1];
    const T z = cand[2];
    const T d =
        (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);

    if (d < nearest) {
      // New best: shift the previous best and second down one slot.
      third = second;
      third_i = second_i;
      second = nearest;
      second_i = nearest_i;
      nearest = d;
      nearest_i = k;
    } else if (d < second) {
      // Falls between best and second: only the second slot shifts down.
      third = second;
      third_i = second_i;
      second = d;
      second_i = k;
    } else if (d < third) {
      // Only replaces the last slot.
      third = d;
      third_i = k;
    }
  }

  out_dist[0] = nearest;
  out_dist[1] = second;
  out_dist[2] = third;
  out_idx[0] = nearest_i;
  out_idx[1] = second_i;
  out_idx[2] = third_i;
}

#endif // THREE_NN_CUDA_KERNEL_CUH
66 changes: 66 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/three_interpolate_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "pytorch_cuda_helper.hpp"
#include "three_interpolate_cuda_kernel.cuh"

// Launches three_interpolate_forward_cuda_kernel on the current CUDA stream
// of the device that holds `points`.
//
//   points: (B, C, M)  floating-point or half tensor
//   idx:    (B, N, 3)  accessed as `int` (data_ptr<int>)
//   weight: (B, N, 3)  accessed with the same scalar type as `points`
//   out:    (B, C, N)  output, accessed with the same scalar type as `points`
//
// The launch is asynchronous; only launch-time errors are reported here.
void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
                                               const Tensor points,
                                               const Tensor idx,
                                               const Tensor weight,
                                               Tensor out) {
  // An empty output (B, C or N equal to zero) would produce a grid with a
  // zero-sized dimension, which CUDA rejects as an invalid configuration.
  // There is nothing to compute in that case, so skip the launch.
  if (b == 0 || c == 0 || n == 0) return;

  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // grid.x covers the N points, grid.y the channels, grid.z the batch.
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      points.scalar_type(), "three_interpolate_forward_cuda_kernel", [&] {
        three_interpolate_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, c, m, n, points.data_ptr<scalar_t>(), idx.data_ptr<int>(),
                weight.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}

// Launches three_interpolate_backward_cuda_kernel on the current CUDA stream
// of the device that holds `grad_out`.
//
//   grad_out:    (B, C, N)  floating-point or half tensor
//   idx:         (B, N, 3)  accessed as `int` (data_ptr<int>)
//   weight:      (B, N, 3)  accessed with the same scalar type as `grad_out`
//   grad_points: (B, C, M)  output; the kernel only accumulates into it, so
//                           the caller is responsible for its initial contents
//
// The launch is asynchronous; only launch-time errors are reported here.
void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
                                                const Tensor grad_out,
                                                const Tensor idx,
                                                const Tensor weight,
                                                Tensor grad_points) {
  // With B, C or N equal to zero there is no gradient to scatter, and the
  // grid would have a zero-sized dimension, which CUDA rejects as an invalid
  // configuration. Leave grad_points untouched and skip the launch.
  if (b == 0 || c == 0 || n == 0) return;

  at::cuda::CUDAGuard device_guard(grad_out.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // grid.x covers the N points, grid.y the channels, grid.z the batch.
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_out.scalar_type(), "three_interpolate_backward_cuda_kernel", [&] {
        three_interpolate_backward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                b, c, n, m, grad_out.data_ptr<scalar_t>(), idx.data_ptr<int>(),
                weight.data_ptr<scalar_t>(), grad_points.data_ptr<scalar_t>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
35 changes: 35 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/three_nn_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "pytorch_cuda_helper.hpp"
#include "three_nn_cuda_kernel.cuh"

// Launches three_nn_forward_cuda_kernel on the current CUDA stream of the
// device that holds `unknown`.
//
//   unknown: (B, N, 3)  floating-point or half tensor
//   known:   (B, M, 3)  accessed with the same scalar type as `unknown`
//   dist2:   (B, N, 3)  output, accessed with the same scalar type as `unknown`
//   idx:     (B, N, 3)  output, accessed as `int` (data_ptr<int>)
//
// The launch is asynchronous; only launch-time errors are reported here.
void ThreeNNForwardCUDAKernelLauncher(int b, int n, int m, const Tensor unknown,
                                      const Tensor known, Tensor dist2,
                                      Tensor idx) {
  // With B or N equal to zero the outputs are empty and the grid would have a
  // zero-sized dimension, which CUDA rejects as an invalid configuration.
  // Skip the launch instead of raising on an empty batch.
  if (b == 0 || n == 0) return;

  at::cuda::CUDAGuard device_guard(unknown.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // grid.x covers the N query points, grid.y the batch.
  dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      unknown.scalar_type(), "three_nn_forward_cuda_kernel", [&] {
        three_nn_forward_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            b, n, m, unknown.data_ptr<scalar_t>(), known.data_ptr<scalar_t>(),
            dist2.data_ptr<scalar_t>(), idx.data_ptr<int>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
25 changes: 25 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input, float gamma,
float alpha);

// Bound to Python as `three_interpolate_forward`.
// points_tensor: (B, C, M), idx_tensor: (B, N, 3), weight_tensor: (B, N, 3);
// the result is written into out_tensor: (B, C, N).
void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
Tensor idx_tensor, Tensor weight_tensor,
Tensor out_tensor);

// Bound to Python as `three_interpolate_backward`.
// grad_out_tensor: (B, C, N), idx_tensor / weight_tensor: (B, N, 3);
// the result is written into grad_points_tensor: (B, C, M).
// Note the (n, m) argument order, which differs from the forward call.
void three_interpolate_backward(int b, int c, int n, int m,
Tensor grad_out_tensor, Tensor idx_tensor,
Tensor weight_tensor,
Tensor grad_points_tensor);

// Bound to Python as `three_nn_forward`.
// unknown_tensor: (B, N, 3), known_tensor: (B, M, 3);
// results are written into dist2_tensor and idx_tensor, both (B, N, 3).
void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
Tensor known_tensor, Tensor dist2_tensor,
Tensor idx_tensor);

void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset);

Expand Down Expand Up @@ -343,6 +356,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"softmax_focal_loss_backward", py::arg("input"), py::arg("target"),
py::arg("weight"), py::arg("buff"), py::arg("grad_input"),
py::arg("gamma"), py::arg("alpha"));
m.def("three_interpolate_forward", &three_interpolate_forward,
"three_interpolate_forward", py::arg("b"), py::arg("c"), py::arg("m"),
py::arg("n"), py::arg("points_tensor"), py::arg("idx_tensor"),
py::arg("weight_tensor"), py::arg("out_tensor"));
m.def("three_interpolate_backward", &three_interpolate_backward,
"three_interpolate_backward", py::arg("b"), py::arg("c"), py::arg("n"),
py::arg("m"), py::arg("grad_out_tensor"), py::arg("idx_tensor"),
py::arg("weight_tensor"), py::arg("grad_points_tensor"));
m.def("three_nn_forward", &three_nn_forward, "three_nn_forward", py::arg("b"),
py::arg("n"), py::arg("m"), py::arg("unknown_tensor"),
py::arg("known_tensor"), py::arg("dist2_tensor"),
py::arg("idx_tensor"));
m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
py::arg("aligned"), py::arg("offset"));
Expand Down
62 changes: 62 additions & 0 deletions mmcv/ops/csrc/pytorch/three_interpolate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate.cpp

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
// Defined in the CUDA translation unit; declared here so this file can be
// compiled by the host C++ compiler.
void ThreeInterpolateForwardCUDAKernelLauncher(int b, int c, int m, int n,
                                               const Tensor points,
                                               const Tensor idx,
                                               const Tensor weight, Tensor out);

// CUDA backend of three_interpolate_forward: forwards every argument
// unchanged to the kernel launcher.
//   points: (B, C, M), idx: (B, N, 3), weight: (B, N, 3) -> out: (B, C, N)
void three_interpolate_forward_cuda(int b, int c, int m, int n,
                                    const Tensor points, const Tensor idx,
                                    const Tensor weight, Tensor out) {
  ThreeInterpolateForwardCUDAKernelLauncher(b, c, m, n, points, idx, weight,
                                            out);
}

// Defined in the CUDA translation unit; declared here so this file can be
// compiled by the host C++ compiler.
void ThreeInterpolateBackwardCUDAKernelLauncher(int b, int c, int n, int m,
                                                const Tensor grad_out,
                                                const Tensor idx,
                                                const Tensor weight,
                                                Tensor grad_points);

// CUDA backend of three_interpolate_backward: forwards every argument
// unchanged to the kernel launcher.
//   grad_out: (B, C, N), idx: (B, N, 3), weight: (B, N, 3)
//   -> grad_points: (B, C, M)
void three_interpolate_backward_cuda(int b, int c, int n, int m,
                                     const Tensor grad_out, const Tensor idx,
                                     const Tensor weight, Tensor grad_points) {
  ThreeInterpolateBackwardCUDAKernelLauncher(b, c, n, m, grad_out, idx, weight,
                                             grad_points);
}
#endif

// Device dispatcher for the forward pass of three_interpolate.
// The device of `points_tensor` decides the backend; only CUDA is available.
//   points_tensor: (B, C, M), idx_tensor: (B, N, 3), weight_tensor: (B, N, 3)
//   out_tensor:    (B, C, N), filled by the backend
// Raises an error for CPU tensors, or for CUDA tensors when the extension
// was built without MMCV_WITH_CUDA.
void three_interpolate_forward(int b, int c, int m, int n, Tensor points_tensor,
                               Tensor idx_tensor, Tensor weight_tensor,
                               Tensor out_tensor) {
  const bool on_gpu = points_tensor.device().is_cuda();
  if (!on_gpu) {
    AT_ERROR("three_interpolate is not implemented on CPU");
  } else {
#ifdef MMCV_WITH_CUDA
    three_interpolate_forward_cuda(b, c, m, n, points_tensor, idx_tensor,
                                   weight_tensor, out_tensor);
#else
    AT_ERROR("three_interpolate is not compiled with GPU support");
#endif
  }
}

// Device dispatcher for the backward pass of three_interpolate.
// The device of `grad_out_tensor` decides the backend; only CUDA is available.
//   grad_out_tensor:    (B, C, N)
//   idx_tensor:         (B, N, 3)
//   weight_tensor:      (B, N, 3)
//   grad_points_tensor: (B, C, M), updated by the backend
// Raises an error for CPU tensors, or for CUDA tensors when the extension
// was built without MMCV_WITH_CUDA.
void three_interpolate_backward(int b, int c, int n, int m,
                                Tensor grad_out_tensor, Tensor idx_tensor,
                                Tensor weight_tensor,
                                Tensor grad_points_tensor) {
  const bool on_gpu = grad_out_tensor.device().is_cuda();
  if (!on_gpu) {
    AT_ERROR("three_interpolate is not implemented on CPU");
  } else {
#ifdef MMCV_WITH_CUDA
    three_interpolate_backward_cuda(b, c, n, m, grad_out_tensor, idx_tensor,
                                    weight_tensor, grad_points_tensor);
#else
    AT_ERROR("three_interpolate is not compiled with GPU support");
#endif
  }
}
Loading