Skip to content

Commit

Permalink
[Feature] Add points_in_polygons CUDA op for rotated detection. (#1600)
Browse files Browse the repository at this point in the history
  • Loading branch information
zytx121 authored Dec 24, 2021
1 parent a4dc2a7 commit 304efbb
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- KNN
- MaskedConv
- NMS
- PointsInPolygons
- PSAMask
- RiRoIAlignRotated
- RotatedFeatureAlign
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- KNN
- MaskedConv
- NMS
- PointsInPolygons
- PSAMask
- RotatedFeatureAlign
- RoIPointPool3d
Expand Down
4 changes: 3 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
rel_roi_point_to_rel_img_point)
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
from .points_in_polygons import points_in_polygons
from .points_sampler import PointsSampler
from .psa_mask import PSAMask
from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
Expand Down Expand Up @@ -80,5 +81,6 @@
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons'
]
79 changes: 79 additions & 0 deletions mmcv/ops/csrc/common/cuda/points_in_polygons_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
#define POINTS_IN_POLYGONS_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

struct point {
float x, y;
};

template <typename scalar_t>
__global__ void points_in_polygons_forward_cuda_kernel(
const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
const int rows, const int cols, scalar_t *inside_flag) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int row = index / cols;
int col = index % cols;

const scalar_t *offset_vertex1 = vertex1 + row * 2;
const scalar_t *offset_vertex2 = vertex2 + col * 8;

point point_[1];
point polygon[4];

point_[0].x = offset_vertex1[0];
point_[0].y = offset_vertex1[1];

polygon[0].x = offset_vertex2[0];
polygon[0].y = offset_vertex2[1];
polygon[1].x = offset_vertex2[2];
polygon[1].y = offset_vertex2[3];
polygon[2].x = offset_vertex2[4];
polygon[2].y = offset_vertex2[5];
polygon[3].x = offset_vertex2[6];
polygon[3].y = offset_vertex2[7];

int nCross = 0;
int i, j;
float sx, sy, tx, ty, px, py, x;
for (i = 0, j = 3; i < 4; j = i, i++) {
sx = polygon[i].x;
sy = polygon[i].y;
tx = polygon[j].x;
ty = polygon[j].y;

px = point_[0].x;
py = point_[0].y;

if (py < min(sy, ty)) continue;
if (py > max(sy, ty)) continue;

if ((sx == px && sy == py) || (tx == px && ty == py)) {
break;
} else {
if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
x = sx + (py - sy) * (tx - sx) / (ty - sy);
if (x == px) {
break;
}
if (x > px) {
nCross++;
}
}
}
}
if (nCross % 2 == 1) {
inside_flag[index] = 1.0;
} else {
inside_flag[index] = 0.0;
}
return;
}
}

#endif // POINTS_IN_POLYGONS_CUDA_KERNEL_CUH
19 changes: 19 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1481,3 +1481,22 @@ REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, CUDA,
rotated_feature_align_forward_cuda);
REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, CUDA,
rotated_feature_align_backward_cuda);

void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
const at::Tensor polygons,
const int rows, const int cols,
at::Tensor output);

void points_in_polygons_forward_cuda(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols) {
PointsInPolygonsForwardCUDAKernelLauncher(points, polygons, rows, cols,
output);
};

void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols);

REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, CUDA,
points_in_polygons_forward_cuda);
28 changes: 28 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/points_in_polygons_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/ming71/CUDA/blob/master/point_justify/points_justify_kernel.cu

#include <stdio.h>

#include "points_in_polygons_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points,
const at::Tensor polygons,
const int rows, const int cols,
at::Tensor output) {
const int output_size = rows * cols;
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "points_in_polygons_forward_cuda_kernel", ([&] {
const scalar_t *vertex1 = points.data_ptr<scalar_t>();
const scalar_t *vertex2 = polygons.data_ptr<scalar_t>();
scalar_t *inside_flag = output.data_ptr<scalar_t>();

points_in_polygons_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, vertex1, vertex2, rows, cols, inside_flag);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
15 changes: 15 additions & 0 deletions mmcv/ops/csrc/pytorch/points_in_polygons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons,
Tensor output, const int rows,
const int cols) {
DISPATCH_DEVICE_IMPL(points_in_polygons_forward_impl, points, polygons,
output, rows, cols);
}

void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output) {
int rows = points.size(0);
int cols = polygons.size(0);
points_in_polygons_forward_impl(points, polygons, output, rows, cols);
}
5 changes: 5 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ void riroi_align_rotated_backward(Tensor top_grad, Tensor rois,
int num_samples, int num_orientations,
bool clockwise);

void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
Expand Down Expand Up @@ -726,4 +728,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("num_samples"), py::arg("num_orientations"),
py::arg("clockwise"));
m.def("points_in_polygons_forward", &points_in_polygons_forward,
"points_in_polygons_forward", py::arg("points"), py::arg("polygons"),
py::arg("output"));
}
37 changes: 37 additions & 0 deletions mmcv/ops/points_in_polygons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['points_in_polygons_forward'])


def points_in_polygons(points, polygons):
"""Judging whether points are inside polygons, which is used in the ATSS
assignment for the rotated boxes.
It should be noted that when the point is just at the polygon boundary, the
judgment will be inaccurate, but the effect on assignment is limited.
Args:
points (torch.Tensor): It has shape (B, 2), indicating (x, y).
M means the number of predicted points.
polygons (torch.Tensor): It has shape (M, 8), indicating
(x1, y1, x2, y2, x3, y3, x4, y4). M means the number of
ground truth polygons.
Returns:
torch.Tensor: Return the result with the shape of (B, M),
1 indicates that the point is inside the polygon,
0 indicates that the point is outside the polygon.
"""
assert points.shape[1] == 2, \
'points dimension should be 2, ' \
f'but got unexpected shape {points.shape[1]}'
assert polygons.shape[1] == 8, \
'polygons dimension should be 8, ' \
f'but got unexpected shape {polygons.shape[1]}'
output = torch.full([points.shape[0], polygons.shape[0]],
0.).cuda().float()
ext_module.points_in_polygons_forward(points.contiguous(),
polygons.contiguous(), output)
return output
22 changes: 22 additions & 0 deletions tests/test_ops/test_points_in_polygons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import pytest
import torch

from mmcv.ops import points_in_polygons


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_polygons():
points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250],
[100, 0]])
polygons = np.array([[200., 200., 400., 400., 500., 200., 400., 100.],
[400., 400., 500., 500., 600., 300., 500., 200.],
[300., 300., 600., 700., 700., 700., 700., 100.]])
expected_output = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.],
[1., 0., 0.], [0., 0., 0.]])
points = torch.from_numpy(points).cuda().float()
polygons = torch.from_numpy(polygons).cuda().float()
expected_output = torch.from_numpy(expected_output).cuda().float()
assert torch.allclose(
points_in_polygons(points, polygons), expected_output, 1e-3)

0 comments on commit 304efbb

Please sign in to comment.