Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add points_in_polygons CUDA op for rotated detection. #1600

Merged
merged 20 commits into from
Dec 24, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- KNN
- MaskedConv
- NMS
- PointsInPolygons
- PSAMask
- RoIPointPool3d
- RoIPool
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- KNN
- MaskedConv
- NMS
- PointsInPolygons
- PSAMask
- RoIPointPool3d
- RoIPool
Expand Down
4 changes: 3 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
rel_roi_point_to_rel_img_point)
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from .points_in_polygons import points_in_polygons
from .points_sampler import PointsSampler
from .psa_mask import PSAMask
from .roi_align import RoIAlign, roi_align
Expand Down Expand Up @@ -77,5 +78,6 @@
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons'
]
79 changes: 79 additions & 0 deletions mmcv/ops/csrc/common/cuda/points_in_polygons_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
#define POINTS_IN_POLYGONS_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

struct point {
float x, y;
};

template <typename scalar_t>
__global__ void points_in_polygons_forward_cuda_kernel(
const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
const int rows, const int cols, scalar_t *inside_flag) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int row = index / cols;
int col = index % cols;

const scalar_t *offset_vertex1 = vertex1 + row * 2;
const scalar_t *offset_vertex2 = vertex2 + col * 8;

point point_[1];
point polygon[4];

point_[0].x = offset_vertex1[0];
point_[0].y = offset_vertex1[1];

polygon[0].x = offset_vertex2[0];
polygon[0].y = offset_vertex2[1];
polygon[1].x = offset_vertex2[2];
polygon[1].y = offset_vertex2[3];
polygon[2].x = offset_vertex2[4];
polygon[2].y = offset_vertex2[5];
polygon[3].x = offset_vertex2[6];
polygon[3].y = offset_vertex2[7];

int nCross = 0;
int i, j;
float sx, sy, tx, ty, px, py, x;
for (i = 0, j = 3; i < 4; j = i, i++) {
sx = polygon[i].x;
sy = polygon[i].y;
tx = polygon[j].x;
ty = polygon[j].y;

px = point_[0].x;
py = point_[0].y;

if (py < min(sy, ty)) continue;
if (py > max(sy, ty)) continue;

if ((sx == px && sy == py) || (tx == px && ty == py)) {
break;
} else {
if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
x = sx + (py - sy) * (tx - sx) / (ty - sy);
if (x == px) {
break;
}
if (x > px) {
nCross++;
}
}
}
}
if (nCross % 2 == 1) {
inside_flag[index] = 1.0;
} else {
inside_flag[index] = 0.0;
}
return;
}
}

#endif // POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
19 changes: 19 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1362,3 +1362,22 @@ REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, CUDA,
hard_voxelize_forward_cuda);
REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, CUDA,
dynamic_voxelize_forward_cuda);

void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
const at::Tensor polygons,
const int rows, const int cols,
at::Tensor output);

void points_in_polygons_forward_cuda(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols) {
PointsInPolygonsForwardCUDAKernelLauncher(points, polygons, rows, cols,
output);
};

void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols);

REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA,
points_in_polygons_forward_cuda);
28 changes: 28 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/points_in_polygons_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/ming71/CUDA/blob/master/point_justify/points_justify_kernel.cu

#include <stdio.h>

#include "points_in_polygons_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
const at::Tensor polygons,
const int rows, const int cols,
at::Tensor output) {
const int output_size = rows * cols;
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "PointsInPolygonsLaucher", ([&] {
zytx121 marked this conversation as resolved.
Show resolved Hide resolved
const scalar_t *vertex1 = points.data_ptr<scalar_t>();
const scalar_t *vertex2 = polygons.data_ptr<scalar_t>();
scalar_t *inside_flag = output.data_ptr<scalar_t>();

points_in_polygons_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, vertex1, vertex2, rows, cols, inside_flag);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
15 changes: 15 additions & 0 deletions mmcv/ops/csrc/pytorch/points_in_polygons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols) {
DISPATCH_DEVICE_IMPL(points_in_polygons_forward_impl, points, polygons,
output, rows, cols);
}

void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output) {
int rows = points.size(0);
int cols = polygons.size(0);
points_in_polygons_forward_impl(points, polygons, output, rows, cols);
}
5 changes: 5 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ void correlation_backward(Tensor grad_output, Tensor input1, Tensor input2,
int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW);

void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
Expand Down Expand Up @@ -686,4 +688,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"roiaware_pool3d_backward", py::arg("pts_idx_of_voxels"),
py::arg("argmax"), py::arg("grad_out"), py::arg("grad_in"),
py::arg("pool_method"));
m.def("points_in_polygons_forward", &points_in_polygons_forward,
"points_in_polygons_forward", py::arg("points"), py::arg("polygons"),
py::arg("output"));
}
37 changes: 37 additions & 0 deletions mmcv/ops/points_in_polygons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['points_in_polygons_forward'])


def points_in_polygons(points, polygons):
"""Judging whether points are inside polygons, which is used in the ATSS
assignment for the rotated boxes.

It should be noted that when the point is just at the polygon boundary, the
judgment will be inaccurate, but the effect on assignment is limited.

Args:
points (torch.Tensor): It has shape (B, 2), indicating (x, y).
M means the number of predicted points.
polygons (torch.Tensor): It has shape (M, 8), indicating
(x1, y1, x2, y2, x3, y3, x4, y4). M means the number of
ground truth polygons.

Returns:
torch.Tensor: Return the result with the shape of (B, M),
1 indicates that the point is inside the polygon,
0 indicates that the point is outside the polygon.
"""
assert points.shape[1] == 2, \
'points dimension should be 2, ' \
f'but got unexpected shape {points.shape[1]}'
assert polygons.shape[1] == 8, \
'polygons dimension should be 8, ' \
f'but got unexpected shape {polygons.shape[1]}'
output = torch.full([points.shape[0], polygons.shape[0]],
0.).cuda().float()
ext_module.points_in_polygons_forward(points.contiguous(),
polygons.contiguous(), output)
return output
22 changes: 22 additions & 0 deletions tests/test_ops/test_points_in_polygons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import pytest
import torch

from mmcv.ops import points_in_polygons


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_polygons():
points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250],
[100, 0]])
polygons = np.array([[200., 200., 400., 400., 500., 200., 400., 100.],
[400., 400., 500., 500., 600., 300., 500., 200.],
[300., 300., 600., 700., 700., 700., 700., 100.]])
expected_output = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.],
[1., 0., 0.], [0., 0., 0.]])
points = torch.from_numpy(points).cuda().float()
polygons = torch.from_numpy(polygons).cuda().float()
expert_output = torch.from_numpy(expert_output).cuda().float()
zytx121 marked this conversation as resolved.
Show resolved Hide resolved
assert torch.allclose(
points_in_polygons(points, polygons), expert_output, 1e-3)
zytx121 marked this conversation as resolved.
Show resolved Hide resolved