[Fix] Update test data for test_iou3d (#1427)
* Update test data for test_iou3d

* delete blank lines

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
q.yao and zhouzaida committed Nov 3, 2021
1 parent 36e7aa1 commit 48c7b57
Showing 3 changed files with 37 additions and 47 deletions.
58 changes: 22 additions & 36 deletions mmcv/ops/csrc/pytorch/iou3d.cpp
@@ -134,25 +134,18 @@ int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh) {

   const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

-  unsigned long long *mask_data = NULL;
-  CHECK_ERROR(
-      cudaMalloc((void **)&mask_data,
-                 boxes_num * col_blocks * sizeof(unsigned long long)));
+  Tensor mask =
+      at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
+  unsigned long long *mask_data =
+      (unsigned long long *)mask.data_ptr<int64_t>();
   iou3d_nms_forward_cuda(boxes, mask_data, boxes_num, nms_overlap_thresh);

-  // unsigned long long mask_cpu[boxes_num * col_blocks];
-  // unsigned long long *mask_cpu = new unsigned long long [boxes_num *
-  // col_blocks];
-  std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
+  at::Tensor mask_cpu = mask.to(at::kCPU);
+  unsigned long long *mask_host =
+      (unsigned long long *)mask_cpu.data_ptr<int64_t>();

-  // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
-  CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
-                         boxes_num * col_blocks * sizeof(unsigned long long),
-                         cudaMemcpyDeviceToHost));
-
-  cudaFree(mask_data);
-
-  unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
+  std::vector<unsigned long long> remv_cpu(col_blocks);
+  memset(&remv_cpu[0], 0, sizeof(unsigned long long) * col_blocks);

   int num_to_keep = 0;

@@ -162,13 +155,13 @@ int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh) {

     if (!(remv_cpu[nblock] & (1ULL << inblock))) {
       keep_data[num_to_keep++] = i;
-      unsigned long long *p = &mask_cpu[0] + i * col_blocks;
+      unsigned long long *p = &mask_host[0] + i * col_blocks;
       for (int j = nblock; j < col_blocks; j++) {
         remv_cpu[j] |= p[j];
       }
     }
   }
-  delete[] remv_cpu;

   if (cudaSuccess != cudaGetLastError()) printf("Error!\n");

   return num_to_keep;
@@ -196,26 +189,19 @@ int iou3d_nms_normal_forward(Tensor boxes, Tensor keep,

   const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);

-  unsigned long long *mask_data = NULL;
-  CHECK_ERROR(
-      cudaMalloc((void **)&mask_data,
-                 boxes_num * col_blocks * sizeof(unsigned long long)));
+  Tensor mask =
+      at::empty({boxes_num, col_blocks}, boxes.options().dtype(at::kLong));
+  unsigned long long *mask_data =
+      (unsigned long long *)mask.data_ptr<int64_t>();
   iou3d_nms_normal_forward_cuda(boxes, mask_data, boxes_num,
                                 nms_overlap_thresh);

-  // unsigned long long mask_cpu[boxes_num * col_blocks];
-  // unsigned long long *mask_cpu = new unsigned long long [boxes_num *
-  // col_blocks];
-  std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
-
-  CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
-                         boxes_num * col_blocks * sizeof(unsigned long long),
-                         cudaMemcpyDeviceToHost));
-
-  cudaFree(mask_data);
-
-  unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
+  at::Tensor mask_cpu = mask.to(at::kCPU);
+  unsigned long long *mask_host =
+      (unsigned long long *)mask_cpu.data_ptr<int64_t>();
+
+  std::vector<unsigned long long> remv_cpu(col_blocks);
+  memset(&remv_cpu[0], 0, sizeof(unsigned long long) * col_blocks);
   int num_to_keep = 0;

   for (int i = 0; i < boxes_num; i++) {
@@ -224,13 +210,13 @@ int iou3d_nms_normal_forward(Tensor boxes, Tensor keep,

     if (!(remv_cpu[nblock] & (1ULL << inblock))) {
       keep_data[num_to_keep++] = i;
-      unsigned long long *p = &mask_cpu[0] + i * col_blocks;
+      unsigned long long *p = &mask_host[0] + i * col_blocks;
       for (int j = nblock; j < col_blocks; j++) {
         remv_cpu[j] |= p[j];
       }
     }
   }
-  delete[] remv_cpu;

   if (cudaSuccess != cudaGetLastError()) printf("Error!\n");

   return num_to_keep;
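The C++ change above replaces the raw cudaMalloc / cudaMemcpy / cudaFree calls and the new[]/delete[] pair with an at::empty tensor and a std::vector, so the NMS bitmask buffer is owned by PyTorch's caching allocator and is released automatically. As a rough illustration only (not part of the commit; the value of THREADS_PER_BLOCK_NMS is assumed to be 64 here), the same allocation pattern looks like this through the Python API:

```python
import torch

# Illustration of the allocation pattern the C++ code now uses: the bitmask is
# a [boxes_num, col_blocks] int64 tensor managed by PyTorch, so no manual
# cudaMalloc/cudaMemcpy/cudaFree is required.
THREADS_PER_BLOCK_NMS = 64  # assumed; one bit per box in an unsigned long long

def allocate_nms_mask(boxes: torch.Tensor) -> torch.Tensor:
    boxes_num = boxes.size(0)
    col_blocks = (boxes_num + THREADS_PER_BLOCK_NMS - 1) // THREADS_PER_BLOCK_NMS
    # counterpart of at::empty({boxes_num, col_blocks}, ...dtype(at::kLong))
    return boxes.new_empty((boxes_num, col_blocks), dtype=torch.int64)

# Calling mask.cpu() then mirrors mask.to(at::kCPU): an explicit device-to-host
# copy that replaces the former CHECK_ERROR(cudaMemcpy(..., cudaMemcpyDeviceToHost)).
```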
2 changes: 2 additions & 0 deletions mmcv/ops/iou3d.py
@@ -47,6 +47,7 @@ def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
     Returns:
         torch.Tensor: Indexes after NMS.
     """
+    assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
     order = scores.sort(0, descending=True)[1]

     if pre_max_size is not None:
@@ -74,6 +75,7 @@ def nms_normal_bev(boxes, scores, thresh):
     Returns:
         torch.Tensor: Remaining indices with scores in descending order.
     """
+    assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
     order = scores.sort(0, descending=True)[1]

     boxes = boxes[order].contiguous()
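With the added asserts, both nms_bev and nms_normal_bev fail fast when the boxes are not of shape [N, 5]. A minimal usage sketch (the column layout [x1, y1, x2, y2, ry] follows the functions' docstrings; a CUDA device is assumed, since both ops run on the GPU):

```python
import torch
from mmcv.ops import nms_bev, nms_normal_bev

# Two BEV boxes in [x1, y1, x2, y2, ry] layout, satisfying the new [N, 5] check.
boxes = torch.tensor([[6.0, 3.0, 8.0, 7.0, 2.0],
                      [3.0, 6.0, 9.0, 11.0, 1.0]]).cuda()
scores = torch.tensor([0.6, 0.9]).cuda()

keep = nms_bev(boxes, scores, thresh=0.3)                # rotation-aware BEV NMS
keep_normal = nms_normal_bev(boxes, scores, thresh=0.3)  # overlap computed with yaw set to 0

# Passing the old 4-column boxes now raises:
# AssertionError: Input boxes shape should be [N, 5]
```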
24 changes: 13 additions & 11 deletions tests/test_ops/test_iou3d.py
@@ -30,29 +30,31 @@ def test_boxes_iou_bev():

 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-def test_nms_gpu():
-    np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
-                         [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
-                        dtype=np.float32)
+def test_nms_bev():
+    np_boxes = np.array(
+        [[6.0, 3.0, 8.0, 7.0, 2.0], [3.0, 6.0, 9.0, 11.0, 1.0],
+         [3.0, 7.0, 10.0, 12.0, 1.0], [1.0, 4.0, 13.0, 7.0, 3.0]],
+        dtype=np.float32)
     np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
     np_inds = np.array([1, 0, 3])
     boxes = torch.from_numpy(np_boxes)
     scores = torch.from_numpy(np_scores)
     inds = nms_bev(boxes.cuda(), scores.cuda(), thresh=0.3)

-    assert np.allclose(inds.cpu().numpy(), np_inds)  # test gpu
+    assert np.allclose(inds.cpu().numpy(), np_inds)


 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-def test_nms_normal_gpu():
-    np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
-                         [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
-                        dtype=np.float32)
+def test_nms_normal_bev():
+    np_boxes = np.array(
+        [[6.0, 3.0, 8.0, 7.0, 2.0], [3.0, 6.0, 9.0, 11.0, 1.0],
+         [3.0, 7.0, 10.0, 12.0, 1.0], [1.0, 4.0, 13.0, 7.0, 3.0]],
+        dtype=np.float32)
     np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
-    np_inds = np.array([1, 2, 0, 3])
+    np_inds = np.array([1, 0, 3])
     boxes = torch.from_numpy(np_boxes)
     scores = torch.from_numpy(np_scores)
     inds = nms_normal_bev(boxes.cuda(), scores.cuda(), thresh=0.3)

-    assert np.allclose(inds.cpu().numpy(), np_inds)  # test gpu
+    assert np.allclose(inds.cpu().numpy(), np_inds)
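
The two updated tests can be run on their own on a CUDA machine; a small sketch, assuming the repository root as the working directory:

```python
import pytest

# Both tests skip automatically when CUDA is unavailable (see the skipif marks).
pytest.main(['tests/test_ops/test_iou3d.py', '-k', 'nms_bev or nms_normal_bev', '-v'])
```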

