Skip to content

Commit

Permalink
add stack sa model ops
Browse files Browse the repository at this point in the history
  • Loading branch information
VVsssssk committed Oct 14, 2022
1 parent 7d075d1 commit f218a98
Show file tree
Hide file tree
Showing 13 changed files with 770 additions and 65 deletions.
69 changes: 48 additions & 21 deletions mmcv/ops/ball_query.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
ext_module = ext_loader.load_ext(
'_ext', ['ball_query_forward', 'stack_ball_query_forward'])


class BallQuery(Function):
"""Find nearby points in spherical space."""

@staticmethod
def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
def forward(
ctx,
min_radius: float,
max_radius: float,
sample_num: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor,
xyz_batch_cnt: Optional[torch.Tensor] = None,
center_xyz_batch_cnt: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls.
xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features.
xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
or stacked input (N1 + N2 ..., 3).
center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
query.
query, or stacked input (M1 + M2 ..., 3).
xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums in
each batch, just like (N1, N2, ...). Default None.
center_xyz_batch_cnt: (batch_size): Stacked centers coordinates
nums in each batch, just like (M1, M2, ...). Default None.
Returns:
torch.Tensor: (B, npoint, nsample) tensor with the indices of the
Expand All @@ -31,21 +45,34 @@ def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
assert center_xyz.is_contiguous()
assert xyz.is_contiguous()
assert min_radius < max_radius

B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)

ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
assert xyz_batch_cnt.dtype == torch.int
assert center_xyz_batch_cnt.dtype == torch.int
idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
dtype=torch.int32)
ext_module.stack_ball_query_forward(
center_xyz,
center_xyz_batch_cnt,
xyz,
xyz_batch_cnt,
idx,
max_radius=max_radius,
nsample=sample_num,
)
else:
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx
Expand Down
75 changes: 75 additions & 0 deletions mmcv/ops/csrc/common/cuda/stack_ball_query_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#ifndef STACK_BALL_QUERY_CUDA_KERNEL_CUH
#define STACK_BALL_QUERY_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__global__ void stack_ball_query_forward_cuda_kernel(
    int B, int M, float radius, int nsample, const T *new_xyz,
    const int *new_xyz_batch_cnt, const T *xyz, const int *xyz_batch_cnt,
    int *idx) {
  // Ball query over stacked (variable-length) batches: for every query
  // center, record up to `nsample` indices of points from the same batch
  // sample whose squared distance to the center is below radius^2.
  //
  // :param B: number of batch samples (length of both *_batch_cnt arrays)
  // :param M: total number of query centers, M1 + M2 + ...
  // :param radius: ball radius (strict `<` comparison on squared distance)
  // :param nsample: number of index slots per query center
  // :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
  // :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
  // :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
  // :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
  // output:
  //      idx: (M, nsample). Stored indices are relative to the start of the
  //           owning batch sample inside `xyz`. On the first hit all
  //           `nsample` slots are filled with that index, so unused slots
  //           repeat the first neighbour. When no point is in range only
  //           slot 0 is written (with -1); the other slots are untouched.
  //
  // NOTE(review): assumes sum(new_xyz_batch_cnt) == M; otherwise bs_idx can
  // reach B and xyz_batch_cnt[bs_idx] is read out of range -- confirm that
  // callers guarantee this.

  // Loop-invariant, so computed once per thread.
  const float radius2 = radius * radius;

  CUDA_1D_KERNEL_LOOP(pt_idx, M) {
    // Find the batch sample that owns this query center.
    int bs_idx = 0;
    for (int pt_cnt = 0; bs_idx < B; bs_idx++) {
      pt_cnt += new_xyz_batch_cnt[bs_idx];
      if (pt_idx < pt_cnt) break;
    }

    // Offset (in points) of that batch sample inside the stacked `xyz`.
    int xyz_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];

    // Derive every pointer from its base on each iteration. Advancing
    // pointers declared outside the loop with `+=` would accumulate offsets
    // whenever one thread handles more than one pt_idx.
    const T *new_xyz_p = new_xyz + pt_idx * 3;
    const T *cur_xyz = xyz + xyz_batch_start_idx * 3;
    int *cur_idx = idx + pt_idx * nsample;

    T new_x = new_xyz_p[0];
    T new_y = new_xyz_p[1];
    T new_z = new_xyz_p[2];
    int n = xyz_batch_cnt[bs_idx];

    int cnt = 0;
    for (int k = 0; k < n; ++k) {
      T x = cur_xyz[k * 3 + 0];
      T y = cur_xyz[k * 3 + 1];
      T z = cur_xyz[k * 3 + 2];
      T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
             (new_z - z) * (new_z - z);
      if (d2 < radius2) {
        if (cnt == 0) {
          // First neighbour: pre-fill every slot so the row is fully
          // populated even if fewer than nsample points are in range.
          for (int l = 0; l < nsample; ++l) {
            cur_idx[l] = k;
          }
        }
        cur_idx[cnt] = k;
        ++cnt;
        if (cnt >= nsample) break;
      }
    }
    // Sentinel marking an empty ball.
    if (cnt == 0) cur_idx[0] = -1;
  }
}

#endif // STACK_BALL_QUERY_CUDA_KERNEL_CUH
83 changes: 83 additions & 0 deletions mmcv/ops/csrc/common/cuda/stack_group_points_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#ifndef STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#define STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__global__ void stack_group_points_forward_cuda_kernel(
    int b, int c, int m, int nsample, const T *features,
    const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt,
    T *out) {
  // Gather features for stacked (variable-length) batches: one element of
  // `out` is produced per (query point, channel, sample) triple.
  //
  // :param b: number of batch samples (length of both *_batch_cnt arrays)
  // :param c: number of feature channels
  // :param m: total number of query points, M1 + M2 + ...
  // :param nsample: number of grouped indices per query point
  // :param features: (N1 + N2 ..., C) tensor of features to group
  // :param features_batch_cnt: (batch_size) [N1, N2, ...] number of feature
  //      rows belonging to each batch sample
  // :param idx: (M1 + M2 ..., nsample) indices of the features to group,
  //      relative to the start of the owning batch sample
  // :param idx_batch_cnt: (batch_size) [M1, M2, ...] number of query points
  //      belonging to each batch sample
  // :return:
  //      out: (M1 + M2 ..., C, nsample) tensor
  //
  // NOTE: values read from `idx` are used unchecked, so any negative
  // sentinel must be replaced by the caller before grouping.
  CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
    // Decompose the flat index as (pt_idx, c_idx, sample_idx). By
    // construction c_idx < c and sample_idx < nsample, so no range check is
    // required.
    const int sample_idx = index % nsample;
    const int c_idx = (index / nsample) % c;
    const int pt_idx = index / nsample / c;

    // Find the batch sample that owns this query point.
    int bs_idx = 0;
    for (int pt_cnt = 0; bs_idx < b; bs_idx++) {
      pt_cnt += idx_batch_cnt[bs_idx];
      if (pt_idx < pt_cnt) break;
    }

    // Offset (in rows) of that batch sample inside the stacked `features`.
    int features_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++)
      features_batch_start_idx += features_batch_cnt[k];

    // Derive pointers from their bases on each iteration. Advancing
    // pointers declared outside the loop with `+=` would accumulate offsets
    // whenever one thread handles more than one index.
    const T *cur_features = features + features_batch_start_idx * c;
    const int *cur_idx = idx + pt_idx * nsample + sample_idx;

    const int in_idx = cur_idx[0] * c + c_idx;
    const int out_idx = pt_idx * c * nsample + c_idx * nsample + sample_idx;

    out[out_idx] = cur_features[in_idx];
  }
}

template <typename T>
__global__ void stack_group_points_backward_cuda_kernel(
    int b, int c, int m, int n, int nsample, const T *grad_out, const int *idx,
    const int *idx_batch_cnt, const int *features_batch_cnt,
    T *grad_features) {
  // Scatter-add the gradient of the grouped output back onto the stacked
  // feature rows selected in the forward pass.
  //
  // :param b: number of batch samples (length of both *_batch_cnt arrays)
  // :param c: number of feature channels
  // :param m: total number of query points, M1 + M2 + ...
  // :param n: not used by this kernel
  // :param nsample: number of grouped indices per query point
  // :param grad_out: (M1 + M2 ..., C, nsample) gradients of the forward
  //      output
  // :param idx: (M1 + M2 ..., nsample) indices of the grouped features,
  //      relative to the start of the owning batch sample
  // :param idx_batch_cnt: (batch_size) [M1, M2, ...] number of query points
  //      belonging to each batch sample
  // :param features_batch_cnt: (batch_size) [N1, N2, ...] number of feature
  //      rows belonging to each batch sample
  // :return:
  //      grad_features: (N1 + N2 ..., C) gradient of the features,
  //      accumulated with atomicAdd because several (point, sample) pairs
  //      may reference the same feature row.
  CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
    // Decompose the flat index as (pt_idx, c_idx, sample_idx). By
    // construction c_idx < c and sample_idx < nsample, so no range check is
    // required.
    const int sample_idx = index % nsample;
    const int c_idx = (index / nsample) % c;
    const int pt_idx = index / nsample / c;

    // Find the batch sample that owns this query point (clamped to b - 1).
    int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
    for (int k = 1; k < b; k++) {
      if (pt_idx < pt_cnt) break;
      pt_cnt += idx_batch_cnt[k];
      bs_idx = k;
    }

    // Offset (in rows) of that batch sample inside the stacked features.
    int features_batch_start_idx = 0;
    for (int k = 0; k < bs_idx; k++)
      features_batch_start_idx += features_batch_cnt[k];

    // Derive pointers from their bases on each iteration. Advancing
    // `grad_features` (or pointers declared outside the loop) with `+=`
    // would accumulate offsets whenever one thread handles more than one
    // index.
    const T *cur_grad_out =
        grad_out + pt_idx * c * nsample + c_idx * nsample + sample_idx;
    const int *cur_idx = idx + pt_idx * nsample + sample_idx;
    T *cur_grad_features =
        grad_features + (features_batch_start_idx + cur_idx[0]) * c + c_idx;

    atomicAdd(cur_grad_features, cur_grad_out[0]);
  }
}

#endif  // STACK_GROUP_POINTS_CUDA_KERNEL_CUH
1 change: 1 addition & 0 deletions mmcv/ops/csrc/common/pytorch_cuda_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ using at::Tensor;
using phalf = at::Half;

#define __PHALF(x) (x)
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))

#endif // PYTORCH_CUDA_HELPER
16 changes: 16 additions & 0 deletions mmcv/ops/csrc/pytorch/ball_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,19 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
ball_query_forward_impl(b, n, m, min_radius, max_radius, nsample,
new_xyz_tensor, xyz_tensor, idx_tensor);
}

// Device-dispatch entry for the stacked ball query: forwards all arguments,
// unchanged and in the same order, to the implementation registered for the
// device of the input tensors.
void stack_ball_query_forward_impl(float max_radius, int nsample,
                                   const Tensor new_xyz,
                                   const Tensor new_xyz_batch_cnt,
                                   const Tensor xyz,
                                   const Tensor xyz_batch_cnt, Tensor idx) {
  DISPATCH_DEVICE_IMPL(stack_ball_query_forward_impl, max_radius, nsample,
                       new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
}

// Public entry point of the stacked ball query. It only reorders its
// arguments (radius and sample count first) and hands them to
// stack_ball_query_forward_impl; the result is written into idx_tensor.
void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt,
                              Tensor xyz_tensor, Tensor xyz_batch_cnt,
                              Tensor idx_tensor, float max_radius,
                              int nsample) {
  stack_ball_query_forward_impl(max_radius, nsample, new_xyz_tensor,
                                new_xyz_batch_cnt, xyz_tensor, xyz_batch_cnt,
                                idx_tensor);
}
55 changes: 55 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,24 @@ void ball_query_forward_impl(int b, int n, int m, float min_radius,
Tensor idx);
REGISTER_DEVICE_IMPL(ball_query_forward_impl, CUDA, ball_query_forward_cuda);

// Forward declaration of the CUDA launcher for the stacked ball query; the
// definition lives in a separate translation unit.
void StackBallQueryForwardCUDAKernelLauncher(
    float max_radius, int nsample, const Tensor new_xyz,
    const Tensor new_xyz_batch_cnt, const Tensor xyz,
    const Tensor xyz_batch_cnt, Tensor idx);

// CUDA backend of the stacked ball query: a pass-through to the kernel
// launcher with identical arguments.
void stack_ball_query_forward_cuda(float max_radius, int nsample,
                                   const Tensor new_xyz,
                                   const Tensor new_xyz_batch_cnt,
                                   const Tensor xyz,
                                   const Tensor xyz_batch_cnt, Tensor idx) {
  StackBallQueryForwardCUDAKernelLauncher(max_radius, nsample, new_xyz,
                                          new_xyz_batch_cnt, xyz,
                                          xyz_batch_cnt, idx);
}

// Declaration of the device-dispatch entry, followed by the registration of
// stack_ball_query_forward_cuda as its CUDA implementation.
void stack_ball_query_forward_impl(float max_radius, int nsample,
                                   const Tensor new_xyz,
                                   const Tensor new_xyz_batch_cnt,
                                   const Tensor xyz,
                                   const Tensor xyz_batch_cnt, Tensor idx);
REGISTER_DEVICE_IMPL(stack_ball_query_forward_impl, CUDA,
                     stack_ball_query_forward_cuda);

void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode,
const bool aligned, const int offset);
Expand Down Expand Up @@ -564,6 +582,43 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA,
REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA,
group_points_backward_cuda);

// Forward declarations of the CUDA launchers for stacked point grouping
// (forward gather and backward scatter); defined in a separate translation
// unit.
void StackGroupPointsForwardCUDAKernelLauncher(
    int b, int c, int m, int nsample, const Tensor features_tensor,
    const Tensor features_batch_cnt_tensor, const Tensor idx_tensor,
    const Tensor idx_batch_cnt_tensor, Tensor out_tensor);
void StackGroupPointsBackwardCUDAKernelLauncher(
    int b, int c, int m, int n, int nsample, const Tensor grad_out_tensor,
    const Tensor idx_tensor, const Tensor idx_batch_cnt_tensor,
    const Tensor features_batch_cnt_tensor, Tensor grad_features_tensor);

// CUDA backend of the stacked group-points forward pass: a pass-through to
// the kernel launcher with identical arguments.
void stack_group_points_forward_cuda(int b, int c, int m, int nsample,
                                     const Tensor features_tensor,
                                     const Tensor features_batch_cnt_tensor,
                                     const Tensor idx_tensor,
                                     const Tensor idx_batch_cnt_tensor,
                                     Tensor out_tensor) {
  StackGroupPointsForwardCUDAKernelLauncher(
      b, c, m, nsample, features_tensor, features_batch_cnt_tensor, idx_tensor,
      idx_batch_cnt_tensor, out_tensor);
}

// CUDA backend of the stacked group-points backward pass: a pass-through to
// the kernel launcher with identical arguments.
void stack_group_points_backward_cuda(int b, int c, int m, int n, int nsample,
                                      const Tensor grad_out_tensor,
                                      const Tensor idx_tensor,
                                      const Tensor idx_batch_cnt_tensor,
                                      const Tensor features_batch_cnt_tensor,
                                      Tensor grad_features_tensor) {
  StackGroupPointsBackwardCUDAKernelLauncher(
      b, c, m, n, nsample, grad_out_tensor, idx_tensor, idx_batch_cnt_tensor,
      features_batch_cnt_tensor, grad_features_tensor);
}

// Declarations of the device-dispatch entries for stacked point grouping,
// followed by the registration of their CUDA implementations.
void stack_group_points_forward_impl(int b, int c, int m, int nsample,
                                     const Tensor features_tensor,
                                     const Tensor features_batch_cnt_tensor,
                                     const Tensor idx_tensor,
                                     const Tensor idx_batch_cnt_tensor,
                                     Tensor out_tensor);

void stack_group_points_backward_impl(int b, int c, int m, int n, int nsample,
                                      const Tensor grad_out_tensor,
                                      const Tensor idx_tensor,
                                      const Tensor idx_batch_cnt_tensor,
                                      const Tensor features_batch_cnt_tensor,
                                      Tensor grad_features_tensor);

REGISTER_DEVICE_IMPL(stack_group_points_forward_impl, CUDA,
                     stack_group_points_forward_cuda);
REGISTER_DEVICE_IMPL(stack_group_points_backward_impl, CUDA,
                     stack_group_points_backward_cuda);

void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a,
const int num_b,
Expand Down
42 changes: 42 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/stack_ball_query_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "stack_ball_query_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))


// Launches the stacked ball-query kernel on the current CUDA stream of the
// device that holds `new_xyz`.
//
// max_radius:        ball radius handed to the kernel; only an upper bound
//                    is applied, there is no minimum-radius argument here.
// nsample:           number of index slots per query center.
// new_xyz:           query centers; size(0) is the total center count M.
// new_xyz_batch_cnt: int tensor, number of centers per batch sample.
// xyz:               points, same scalar type as new_xyz (the dispatch is
//                    keyed on new_xyz's scalar type).
// xyz_batch_cnt:     int tensor, number of points per batch sample; size(0)
//                    gives the batch size B.
// idx:               int output tensor written by the kernel.
void StackBallQueryForwardCUDAKernelLauncher(
    float max_radius, int nsample, const Tensor new_xyz,
    const Tensor new_xyz_batch_cnt, const Tensor xyz,
    const Tensor xyz_batch_cnt, Tensor idx) {
  at::cuda::CUDAGuard device_guard(new_xyz.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int B = xyz_batch_cnt.size(0);
  int M = new_xyz.size(0);

  // With no query centers there is nothing to compute, and a grid of zero
  // blocks is an invalid launch configuration that the error check below
  // would surface as a failure.
  if (M == 0) return;

  // One thread per query center, rounded up to whole blocks.
  dim3 blocks(DIVUP(M, THREADS_PER_BLOCK));
  dim3 threads(THREADS_PER_BLOCK);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      new_xyz.scalar_type(), "stack_ball_query_forward_cuda_kernel", [&] {
        stack_ball_query_forward_cuda_kernel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                B, M, max_radius, nsample, new_xyz.data_ptr<scalar_t>(),
                new_xyz_batch_cnt.data_ptr<int>(), xyz.data_ptr<scalar_t>(),
                xyz_batch_cnt.data_ptr<int>(), idx.data_ptr<int>());
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
Loading

0 comments on commit f218a98

Please sign in to comment.