-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
VVsssssk
committed
Oct 14, 2022
1 parent
7d075d1
commit f218a98
Showing
13 changed files
with
770 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
mmcv/ops/csrc/common/cuda/stack_ball_query_cuda_kernel.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
// Modified from | ||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu | ||
#ifndef STACK_BALL_QUERY_CUDA_KERNEL_CUH | ||
#define STACK_BALL_QUERY_CUDA_KERNEL_CUH | ||
|
||
#ifdef MMCV_USE_PARROTS | ||
#include "parrots_cuda_helper.hpp" | ||
#else | ||
#include "pytorch_cuda_helper.hpp" | ||
#endif | ||
|
||
template <typename T>
__global__ void stack_ball_query_forward_cuda_kernel(
    int B, int M, float radius, int nsample, const T *new_xyz,
    const int *new_xyz_batch_cnt, const T *xyz, const int *xyz_batch_cnt,
    int *idx) {
  // :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
  // :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
  // :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
  // :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
  // output:
  //      idx: (M, nsample)
  CUDA_1D_KERNEL_LOOP(pt_idx, M) {
    // Locate the batch sample that owns query point pt_idx.
    int bs_idx = 0;
    for (int pt_cnt = 0; bs_idx < B; bs_idx++) {
      pt_cnt += new_xyz_batch_cnt[bs_idx];
      if (pt_idx < pt_cnt) break;
    }

    // Start offset of this sample's points inside the stacked xyz tensor.
    int xyz_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];

    // BUGFIX: pointers must be recomputed from the base tensor on every
    // grid-stride iteration. The previous code declared cur_xyz/cur_idx
    // outside the loop and advanced them with `+=`, so any thread that
    // processed more than one pt_idx accumulated stale offsets and
    // read/wrote out-of-place memory.
    const T *new_xyz_p = new_xyz + pt_idx * 3;
    const T *cur_xyz = xyz + xyz_batch_start_idx * 3;
    int *cur_idx = idx + pt_idx * nsample;

    float radius2 = radius * radius;
    T new_x = new_xyz_p[0];
    T new_y = new_xyz_p[1];
    T new_z = new_xyz_p[2];
    int n = xyz_batch_cnt[bs_idx];  // number of points in this sample

    int cnt = 0;
    for (int k = 0; k < n; ++k) {
      T x = cur_xyz[k * 3 + 0];
      T y = cur_xyz[k * 3 + 1];
      T z = cur_xyz[k * 3 + 2];
      T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
             (new_z - z) * (new_z - z);
      if (d2 < radius2) {
        if (cnt == 0) {
          // Pre-fill the whole slot with the first neighbour so unfilled
          // entries are still valid point indices.
          for (int l = 0; l < nsample; ++l) {
            cur_idx[l] = k;
          }
        }
        cur_idx[cnt] = k;
        ++cnt;
        if (cnt >= nsample) break;
      }
    }
    // No neighbour within radius: mark the slot as empty.
    if (cnt == 0) cur_idx[0] = -1;
  }
}
|
||
#endif // STACK_BALL_QUERY_CUDA_KERNEL_CUH |
83 changes: 83 additions & 0 deletions
83
mmcv/ops/csrc/common/cuda/stack_group_points_cuda_kernel.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
// Modified from | ||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu | ||
#ifndef STACK_GROUP_POINTS_CUDA_KERNEL_CUH | ||
#define STACK_GROUP_POINTS_CUDA_KERNEL_CUH | ||
#ifdef MMCV_USE_PARROTS | ||
#include "parrots_cuda_helper.hpp" | ||
#else | ||
#include "pytorch_cuda_helper.hpp" | ||
#endif | ||
|
||
template <typename T>
__global__ void stack_group_points_forward_cuda_kernel(
    int b, int c, int m, int nsample, const T *features,
    const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt,
    T *out) {
  // :param features: (N1 + N2 ..., C) tensor of features to group
  // :param features_batch_cnt: (batch_size) [N1, N2, ...] point count per
  //     sample in the stacked features tensor
  // :param idx: (M1 + M2 ..., nsample) indices of features to group with
  // :param idx_batch_cnt: (batch_size) [M1, M2, ...] index count per sample
  // :return:
  //     out: (M1 + M2 ..., C, nsample) tensor
  CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
    // Decompose the flat index into (pt_idx, c_idx, sample_idx).
    int sample_idx = index % nsample;
    int c_idx = (index / nsample) % c;
    int pt_idx = index / nsample / c;

    // Locate the batch sample that owns grouped point pt_idx.
    // (Previous code also kept a dead, shadowed `pt_cnt = idx_batch_cnt[0]`
    // initializer and an always-false bounds check that `break`-ed out of
    // the grid-stride loop; both removed.)
    int bs_idx = 0;
    for (int pt_cnt = 0; bs_idx < b; bs_idx++) {
      pt_cnt += idx_batch_cnt[bs_idx];
      if (pt_idx < pt_cnt) break;
    }

    // Start offset of this sample's features in the stacked tensor.
    int features_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++)
      features_batch_start_idx += features_batch_cnt[k];

    // BUGFIX: recompute pointers from the base on every grid-stride
    // iteration; the previous `+=` on loop-external pointers accumulated
    // offsets when one thread handled several elements.
    const T *cur_features = features + features_batch_start_idx * c;
    const int *cur_idx = idx + pt_idx * nsample + sample_idx;

    int in_idx = cur_idx[0] * c + c_idx;
    int out_idx = pt_idx * c * nsample + c_idx * nsample + sample_idx;
    out[out_idx] = cur_features[in_idx];
  }
}
|
||
template <typename T>
__global__ void stack_group_points_backward_cuda_kernel(
    int b, int c, int m, int n, int nsample, const T *grad_out,
    const int *idx, const int *idx_batch_cnt, const int *features_batch_cnt,
    T *grad_features) {
  // :param grad_out: (M1 + M2 ..., C, nsample) gradients of the forward output
  // :param idx: (M1 + M2 ..., nsample) indices of features to group with
  // :param idx_batch_cnt: (batch_size) [M1, M2, ...] index count per sample
  // :param features_batch_cnt: (batch_size) [N1, N2, ...] point count per
  //     sample in the stacked features tensor
  // :return:
  //     grad_features: (N1 + N2 ..., C) gradient of the features
  CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
    // Decompose the flat index into (pt_idx, c_idx, sample_idx).
    int sample_idx = index % nsample;
    int c_idx = (index / nsample) % c;
    int pt_idx = index / nsample / c;

    // Locate the batch sample that owns grouped point pt_idx.
    int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
    for (int k = 1; k < b; k++) {
      if (pt_idx < pt_cnt) break;
      pt_cnt += idx_batch_cnt[k];
      bs_idx = k;
    }

    // Start offset of this sample's features in the stacked tensor.
    int features_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++)
      features_batch_start_idx += features_batch_cnt[k];

    // BUGFIX: compute fresh pointers each grid-stride iteration. The
    // previous code advanced cur_grad_out/cur_idx and even the
    // grad_features parameter itself with `+=`, so a thread processing a
    // second element scattered gradients to the wrong addresses.
    const T *cur_grad_out =
        grad_out + pt_idx * c * nsample + c_idx * nsample + sample_idx;
    const int *cur_idx = idx + pt_idx * nsample + sample_idx;
    T *cur_grad_features =
        grad_features + (features_batch_start_idx + cur_idx[0]) * c + c_idx;

    // Multiple grouped points may map to the same source feature.
    atomicAdd(cur_grad_features, cur_grad_out[0]);
  }
}
|
||
#endif  // STACK_GROUP_POINTS_CUDA_KERNEL_CUH
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
// Modified from | ||
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu | ||
|
||
#include <math.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
#include "stack_ball_query_cuda_kernel.cuh" | ||
#include "pytorch_cuda_helper.hpp" | ||
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) | ||
|
||
|
||
void StackBallQueryForwardCUDAKernelLauncher(float max_radius, int nsample,
                                             const Tensor new_xyz,
                                             const Tensor new_xyz_batch_cnt,
                                             const Tensor xyz,
                                             const Tensor xyz_batch_cnt,
                                             Tensor idx) {
  // Launch stack_ball_query_forward_cuda_kernel on the current stream:
  // for each of the M stacked query centers in new_xyz, gather up to
  // `nsample` neighbour indices from xyz within `max_radius` into idx.
  at::cuda::CUDAGuard device_guard(new_xyz.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int B = xyz_batch_cnt.size(0);  // batch size
  int M = new_xyz.size(0);        // total number of query centers

  // One thread per query center, 1-D launch.
  dim3 blocks(DIVUP(M, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      new_xyz.scalar_type(), "stack_ball_query_forward_cuda_kernel", [&] {
        stack_ball_query_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                B, M, max_radius, nsample, new_xyz.data_ptr<scalar_t>(),
                new_xyz_batch_cnt.data_ptr<int>(), xyz.data_ptr<scalar_t>(),
                xyz_batch_cnt.data_ptr<int>(), idx.data_ptr<int>());
      });

  // Surface launch-configuration errors immediately.
  AT_CUDA_CHECK(cudaGetLastError());
}
Oops, something went wrong.