-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add points_in_polygons CUDA op for rotated detection. (#1600)
- Loading branch information
Showing
10 changed files
with
210 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
79 changes: 79 additions & 0 deletions
79
mmcv/ops/csrc/common/cuda/points_in_polygons_cuda_kernel.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
#ifndef POINTS_IN_POLYGONS_CUDA_KERNEL_CUH | ||
#define POINTS_IN_POLYGONS_CUDA_KERNEL_CUH | ||
|
||
#ifdef MMCV_USE_PARROTS | ||
#include "parrots_cuda_helper.hpp" | ||
#else | ||
#include "pytorch_cuda_helper.hpp" | ||
#endif | ||
|
||
struct point { | ||
float x, y; | ||
}; | ||
|
||
template <typename scalar_t> | ||
__global__ void points_in_polygons_forward_cuda_kernel( | ||
const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2, | ||
const int rows, const int cols, scalar_t *inside_flag) { | ||
CUDA_1D_KERNEL_LOOP(index, nthreads) { | ||
int row = index / cols; | ||
int col = index % cols; | ||
|
||
const scalar_t *offset_vertex1 = vertex1 + row * 2; | ||
const scalar_t *offset_vertex2 = vertex2 + col * 8; | ||
|
||
point point_[1]; | ||
point polygon[4]; | ||
|
||
point_[0].x = offset_vertex1[0]; | ||
point_[0].y = offset_vertex1[1]; | ||
|
||
polygon[0].x = offset_vertex2[0]; | ||
polygon[0].y = offset_vertex2[1]; | ||
polygon[1].x = offset_vertex2[2]; | ||
polygon[1].y = offset_vertex2[3]; | ||
polygon[2].x = offset_vertex2[4]; | ||
polygon[2].y = offset_vertex2[5]; | ||
polygon[3].x = offset_vertex2[6]; | ||
polygon[3].y = offset_vertex2[7]; | ||
|
||
int nCross = 0; | ||
int i, j; | ||
float sx, sy, tx, ty, px, py, x; | ||
for (i = 0, j = 3; i < 4; j = i, i++) { | ||
sx = polygon[i].x; | ||
sy = polygon[i].y; | ||
tx = polygon[j].x; | ||
ty = polygon[j].y; | ||
|
||
px = point_[0].x; | ||
py = point_[0].y; | ||
|
||
if (py < min(sy, ty)) continue; | ||
if (py > max(sy, ty)) continue; | ||
|
||
if ((sx == px && sy == py) || (tx == px && ty == py)) { | ||
break; | ||
} else { | ||
if ((sy < py && ty >= py) || (sy >= py && ty < py)) { | ||
x = sx + (py - sy) * (tx - sx) / (ty - sy); | ||
if (x == px) { | ||
break; | ||
} | ||
if (x > px) { | ||
nCross++; | ||
} | ||
} | ||
} | ||
} | ||
if (nCross % 2 == 1) { | ||
inside_flag[index] = 1.0; | ||
} else { | ||
inside_flag[index] = 0.0; | ||
} | ||
return; | ||
} | ||
} | ||
|
||
#endif // POINTS_IN_POLYGONS_CUDA_KERNEL_CUH |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
// Modified from | ||
// https://github.com/ming71/CUDA/blob/master/point_justify/points_justify_kernel.cu | ||
|
||
#include <stdio.h> | ||
|
||
#include "points_in_polygons_cuda_kernel.cuh" | ||
#include "pytorch_cuda_helper.hpp" | ||
|
||
void PointsInPolygonsForwardCUDAKernelLauncher(const at::Tensor points, | ||
const at::Tensor polygons, | ||
const int rows, const int cols, | ||
at::Tensor output) { | ||
const int output_size = rows * cols; | ||
at::cuda::CUDAGuard device_guard(points.device()); | ||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | ||
points.scalar_type(), "points_in_polygons_forward_cuda_kernel", ([&] { | ||
const scalar_t *vertex1 = points.data_ptr<scalar_t>(); | ||
const scalar_t *vertex2 = polygons.data_ptr<scalar_t>(); | ||
scalar_t *inside_flag = output.data_ptr<scalar_t>(); | ||
|
||
points_in_polygons_forward_cuda_kernel<scalar_t> | ||
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>( | ||
output_size, vertex1, vertex2, rows, cols, inside_flag); | ||
})); | ||
AT_CUDA_CHECK(cudaGetLastError()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#include "pytorch_cpp_helper.hpp" | ||
#include "pytorch_device_registry.hpp" | ||
|
||
void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons, | ||
Tensor output, const int rows, | ||
const int cols) { | ||
DISPATCH_DEVICE_IMPL(points_in_polygons_forward_impl, points, polygons, | ||
output, rows, cols); | ||
} | ||
|
||
void points_in_polygons_forward(Tensor points, Tensor polygons, Tensor output) { | ||
int rows = points.size(0); | ||
int cols = polygons.size(0); | ||
points_in_polygons_forward_impl(points, polygons, output, rows, cols); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch | ||
|
||
from ..utils import ext_loader | ||
|
||
ext_module = ext_loader.load_ext('_ext', ['points_in_polygons_forward']) | ||
|
||
|
||
def points_in_polygons(points, polygons): | ||
"""Judging whether points are inside polygons, which is used in the ATSS | ||
assignment for the rotated boxes. | ||
It should be noted that when the point is just at the polygon boundary, the | ||
judgment will be inaccurate, but the effect on assignment is limited. | ||
Args: | ||
points (torch.Tensor): It has shape (B, 2), indicating (x, y). | ||
M means the number of predicted points. | ||
polygons (torch.Tensor): It has shape (M, 8), indicating | ||
(x1, y1, x2, y2, x3, y3, x4, y4). M means the number of | ||
ground truth polygons. | ||
Returns: | ||
torch.Tensor: Return the result with the shape of (B, M), | ||
1 indicates that the point is inside the polygon, | ||
0 indicates that the point is outside the polygon. | ||
""" | ||
assert points.shape[1] == 2, \ | ||
'points dimension should be 2, ' \ | ||
f'but got unexpected shape {points.shape[1]}' | ||
assert polygons.shape[1] == 8, \ | ||
'polygons dimension should be 8, ' \ | ||
f'but got unexpected shape {polygons.shape[1]}' | ||
output = torch.full([points.shape[0], polygons.shape[0]], | ||
0.).cuda().float() | ||
ext_module.points_in_polygons_forward(points.contiguous(), | ||
polygons.contiguous(), output) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
|
||
from mmcv.ops import points_in_polygons | ||
|
||
|
||
@pytest.mark.skipif( | ||
not torch.cuda.is_available(), reason='requires CUDA support') | ||
def test_points_in_polygons(): | ||
points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250], | ||
[100, 0]]) | ||
polygons = np.array([[200., 200., 400., 400., 500., 200., 400., 100.], | ||
[400., 400., 500., 500., 600., 300., 500., 200.], | ||
[300., 300., 600., 700., 700., 700., 700., 100.]]) | ||
expected_output = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 0.], | ||
[1., 0., 0.], [0., 0., 0.]]) | ||
points = torch.from_numpy(points).cuda().float() | ||
polygons = torch.from_numpy(polygons).cuda().float() | ||
expected_output = torch.from_numpy(expected_output).cuda().float() | ||
assert torch.allclose( | ||
points_in_polygons(points, polygons), expected_output, 1e-3) |