Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features] Add stack ball query and stack group points ops #2292

Merged
merged 13 commits into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 50 additions & 21 deletions mmcv/ops/ball_query.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
ext_module = ext_loader.load_ext(
'_ext', ['ball_query_forward', 'stack_ball_query_forward'])


class BallQuery(Function):
"""Find nearby points in spherical space."""

@staticmethod
def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
def forward(
ctx,
min_radius: float,
max_radius: float,
sample_num: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor,
xyz_batch_cnt: Optional[torch.Tensor] = None,
center_xyz_batch_cnt: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls.
xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features.
xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
or staked input (N1 + N2 ..., 3).
center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
query.
query, or staked input (M1 + M2 ..., 3).
xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums in
each batch, just like (N1, N2, ...). Defaults to None.
New in version 1.7.0.
center_xyz_batch_cnt: (batch_size): Stacked centers coordinates
nums in each batch, just line (M1, M2, ...). Defaults to None.
New in version 1.7.0.

Returns:
torch.Tensor: (B, npoint, nsample) tensor with the indices of the
Expand All @@ -31,21 +47,34 @@ def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
assert center_xyz.is_contiguous()
assert xyz.is_contiguous()
assert min_radius < max_radius

B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)

ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
assert xyz_batch_cnt.dtype == torch.int
assert center_xyz_batch_cnt.dtype == torch.int
idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
dtype=torch.int32)
ext_module.stack_ball_query_forward(
center_xyz,
center_xyz_batch_cnt,
xyz,
xyz_batch_cnt,
idx,
max_radius=max_radius,
nsample=sample_num,
)
else:
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
ext_module.ball_query_forward(
center_xyz,
xyz,
idx,
b=B,
n=N,
m=npoint,
min_radius=min_radius,
max_radius=max_radius,
nsample=sample_num)
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx
Expand Down
24 changes: 11 additions & 13 deletions mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,19 @@ __global__ void correlation_forward_cuda_kernel(
for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated;
if
WITHIN_BOUNDS(i1, i2, iH, iH) {
for (int j = 0; j < kW; ++j) {
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if
WITHIN_BOUNDS(j1, j2, iW, iW) {
for (int c = thread; c < C; c += WARP_SIZE) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum += v1 * v2;
}
}
if (WITHIN_BOUNDS(i1, i2, iH, iH)) {
for (int j = 0; j < kW; ++j) {
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if (WITHIN_BOUNDS(j1, j2, iW, iW)) {
for (int c = thread; c < C; c += WARP_SIZE) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum += v1 * v2;
}
}
}
}
}
// accumulate
for (int offset = 16; offset > 0; offset /= 2)
Expand Down
68 changes: 68 additions & 0 deletions mmcv/ops/csrc/common/cuda/stack_ball_query_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#ifndef STACK_BALL_QUERY_CUDA_KERNEL_CUH
#define STACK_BALL_QUERY_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

template <typename T>
__global__ void stack_ball_query_forward_cuda_kernel(
int B, int M, float radius, int nsample, const T *new_xyz,
const int *new_xyz_batch_cnt, const T *xyz, const int *xyz_batch_cnt,
int *idx) {
// :param xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// :param xyz_batch_cnt: (batch_size), [N1, N2, ...]
// :param new_xyz: (M1 + M2 ..., 3) centers of the ball query
// :param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// output:
// idx: (M, nsample)
const T *cur_xyz = xyz;
int *cur_idx = idx;
CUDA_1D_KERNEL_LOOP(pt_idx, M) {
int bs_idx = 0;
for (int pt_cnt = 0; bs_idx < B; bs_idx++) {
pt_cnt += new_xyz_batch_cnt[bs_idx];
if (pt_idx < pt_cnt) break;
}

int xyz_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];

const T *new_xyz_p = new_xyz + pt_idx * 3;
cur_xyz += xyz_batch_start_idx * 3;
cur_idx += pt_idx * nsample;

float radius2 = radius * radius;
T new_x = new_xyz_p[0];
T new_y = new_xyz_p[1];
T new_z = new_xyz_p[2];
int n = xyz_batch_cnt[bs_idx];

int cnt = 0;
for (int k = 0; k < n; ++k) {
T x = cur_xyz[k * 3 + 0];
T y = cur_xyz[k * 3 + 1];
T z = cur_xyz[k * 3 + 2];
T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < radius2) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
cur_idx[l] = k;
}
}
cur_idx[cnt] = k;
++cnt;
if (cnt >= nsample) break;
}
}
if (cnt == 0) cur_idx[0] = -1;
}
}

#endif // STACK_BALL_QUERY_CUDA_KERNEL_CUH
97 changes: 97 additions & 0 deletions mmcv/ops/csrc/common/cuda/stack_group_points_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/group_points_gpu.cu
#ifndef STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#define STACK_GROUP_POINTS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include <stdio.h>
template <typename T>
__global__ void stack_group_points_forward_cuda_kernel(
int b, int c, int m, int nsample, const T *features,
const int *features_batch_cnt, const int *idx, const int *idx_batch_cnt,
T *out) {
// :param features: (N1 + N2 ..., C) tensor of features to group
// :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the
// indices of features to group with :param idx: (M1 + M2 ..., nsample) tensor
// containing the indices of features to group with :param idx_batch_cnt:
// (batch_size) [M1 + M2 ...] tensor containing the indices of features to
// group with :return:
// output: (M1 + M2, C, nsample) tensor
CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
const T *cur_features = features;
const int *cur_idx = idx;
int sample_idx = index % nsample;
int c_idx = (index / nsample) % c;
int pt_idx = (index / nsample / c);

if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return;
int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
for (int k = 1; k < b; k++) {
if (pt_idx < pt_cnt) break;
pt_cnt += idx_batch_cnt[k];
bs_idx = k;
}

int features_batch_start_idx = 0;
int features_batch_end_idx = features_batch_cnt[0];
for (int k = 0; k < bs_idx; k++) {
features_batch_start_idx += features_batch_cnt[k];
features_batch_end_idx =
features_batch_start_idx + features_batch_cnt[k + 1];
}
cur_features += features_batch_start_idx * c;

cur_idx += pt_idx * nsample + sample_idx;
int in_idx = cur_idx[0] * c + c_idx;
int out_idx = pt_idx * c * nsample + c_idx * nsample + sample_idx;
if (in_idx < features_batch_end_idx * c) {
out[out_idx] = cur_features[in_idx];
}
}
}

template <typename T>
__global__ void stack_group_points_backward_cuda_kernel(
int b, int c, int m, int n, int nsample, const T *grad_out, const int *idx,
const int *idx_batch_cnt, const int *features_batch_cnt, T *grad_features) {
// :param grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the
// output from forward :param idx: (M1 + M2 ..., nsample) tensor containing
// the indices of features to group with :param idx_batch_cnt: (batch_size)
// [M1 + M2 ...] tensor containing the indices of features to group with
// :param features_batch_cnt: (batch_size) [N1 + N2 ...] tensor containing the
// indices of features to group with :return:
// grad_features: (N1 + N2 ..., C) gradient of the features
CUDA_1D_KERNEL_LOOP(index, m * c * nsample) {
const T *cur_grad_out = grad_out;
const int *cur_idx = idx;
T *cur_grad_features = grad_features;
int sample_idx = index % nsample;
int c_idx = (index / nsample) % c;
int pt_idx = (index / nsample / c);

if (pt_idx >= m || c_idx >= c || sample_idx >= nsample) return;

int bs_idx = 0, pt_cnt = idx_batch_cnt[0];
for (int k = 1; k < b; k++) {
if (pt_idx < pt_cnt) break;
pt_cnt += idx_batch_cnt[k];
bs_idx = k;
}

int features_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++)
features_batch_start_idx += features_batch_cnt[k];

cur_grad_out += pt_idx * c * nsample + c_idx * nsample + sample_idx;
cur_idx += pt_idx * nsample + sample_idx;
cur_grad_features += (features_batch_start_idx + cur_idx[0]) * c + c_idx;

atomicAdd(cur_grad_features, cur_grad_out[0]);
}
}

#endif // GROUP_POINTS_CUDA_KERNEL_CUH
1 change: 1 addition & 0 deletions mmcv/ops/csrc/common/pytorch_cuda_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ using at::Tensor;
using phalf = at::Half;

#define __PHALF(x) (x)
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

#endif // PYTORCH_CUDA_HELPER
18 changes: 18 additions & 0 deletions mmcv/ops/csrc/pytorch/ball_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,21 @@ void ball_query_forward(Tensor new_xyz_tensor, Tensor xyz_tensor,
ball_query_forward_impl(b, n, m, min_radius, max_radius, nsample,
new_xyz_tensor, xyz_tensor, idx_tensor);
}

void stack_ball_query_forward_impl(float max_radius, int nsample,
const Tensor new_xyz,
const Tensor new_xyz_batch_cnt,
const Tensor xyz, const Tensor xyz_batch_cnt,
Tensor idx) {
DISPATCH_DEVICE_IMPL(stack_ball_query_forward_impl, max_radius, nsample,
new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx);
}

void stack_ball_query_forward(Tensor new_xyz_tensor, Tensor new_xyz_batch_cnt,
Tensor xyz_tensor, Tensor xyz_batch_cnt,
Tensor idx_tensor, float max_radius,
int nsample) {
stack_ball_query_forward_impl(max_radius, nsample, new_xyz_tensor,
new_xyz_batch_cnt, xyz_tensor, xyz_batch_cnt,
idx_tensor);
}
Loading