diff --git a/torchvision/csrc/ops/cuda/nms_kernel.cu b/torchvision/csrc/ops/cuda/nms_kernel.cu
index f4be29e2535..9de6fa5bbf3 100644
--- a/torchvision/csrc/ops/cuda/nms_kernel.cu
+++ b/torchvision/csrc/ops/cuda/nms_kernel.cu
@@ -77,6 +77,78 @@ __global__ void nms_kernel_impl(
   }
 }
 
+// Resolves the pairwise suppression mask into per-box keep flags on the GPU.
+//
+// Boxes are visited in index order. A box whose bit is still clear in the
+// shared `removed` bitmap is kept, and its mask row is OR-ed into the bitmap
+// so that the boxes it flags are skipped when they are reached.
+//
+// Launch contract:
+//  - Only threadIdx.x / blockDim.x are used for indexing, so the kernel is
+//    meant to run as a single 1-D thread block.
+//  - Needs ceil_div(n_boxes, threadsPerBlock) * sizeof(unsigned long long)
+//    bytes of dynamic shared memory.
+//
+// keep:     output flags, one per box. Only `true` is ever written, so the
+//           caller must zero the buffer beforehand.
+// dev_mask: n_boxes rows of ceil_div(n_boxes, threadsPerBlock) words; bit b of
+//           word j in row i flags box (j * threadsPerBlock + b) as removed
+//           once box i is kept.
+// n_boxes:  number of boxes.
+__global__ static void gather_keep_from_mask(
+    bool* keep,
+    const unsigned long long* dev_mask,
+    const int n_boxes) {
+  // Taken and adapted from mmcv
+  // https://github.com/open-mmlab/mmcv/blob/03ce9208d18c0a63d7ffa087ea1c2f5661f2441a/mmcv/ops/csrc/common/cuda/nms_cuda_kernel.cuh#L76
+  const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
+  const int thread_id = threadIdx.x;
+
+  // Bitmap of the boxes removed so far, one word per column block.
+  extern __shared__ unsigned long long removed[];
+
+  // Clear the bitmap; shared memory is not initialized on launch.
+  for (int i = thread_id; i < col_blocks; i += blockDim.x) {
+    removed[i] = 0;
+  }
+  __syncthreads();
+
+  for (int nblock = 0; nblock < col_blocks; nblock++) {
+    auto removed_val = removed[nblock];
+    // Every thread takes its copy before the word can be updated below.
+    __syncthreads();
+    const int i_offset = nblock * threadsPerBlock;
+#pragma unroll
+    for (int inblock = 0; inblock < threadsPerBlock; inblock++) {
+      const int i = i_offset + inblock;
+      if (i >= n_boxes) {
+        break;
+      }
+      // Select a candidate and check whether it should be kept.
+      if (!(removed_val & (1ULL << inblock))) {
+        if (thread_id == 0) {
+          // Mark the output.
+          keep[i] = true;
+        }
+        // Widen before multiplying: i * col_blocks grows with
+        // n_boxes * col_blocks and can overflow a 32-bit int on large inputs.
+        auto p = dev_mask + static_cast<long long>(i) * col_blocks;
+        // Remove all boxes that overlap the candidate.
+        for (int j = thread_id; j < col_blocks; j += blockDim.x) {
+          if (j >= nblock) {
+            removed[j] |= p[j];
+          }
+        }
+        __syncthreads();
+        // NOTE(review): there is no barrier between this read and the next
+        // write to removed[nblock]; presumably harmless because row i only
+        // sets bits for boxes after i -- confirm against nms_kernel_impl.
+        removed_val = removed[nblock];
+      }
+    }
+  }
+}
+
 at::Tensor nms_kernel(
     const at::Tensor& dets,
     const at::Tensor& scores,
@@ -133,35 +205,23 @@ at::Tensor nms_kernel(
             (unsigned long long*)mask.data_ptr());
       });
 
-  at::Tensor mask_cpu = mask.to(at::kCPU);
-  unsigned long long* mask_host =
-      (unsigned long long*)mask_cpu.data_ptr();
-
-  std::vector remv(col_blocks);
-  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
-
-  at::Tensor keep =
-      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
-  int64_t* keep_out = keep.data_ptr();
-
-  int num_to_keep = 0;
-  for (int i = 0; i < dets_num; i++) {
-    int nblock = i / threadsPerBlock;
-    int inblock = i % threadsPerBlock;
-
-    if (!(remv[nblock] & (1ULL << inblock))) {
-      keep_out[num_to_keep++] = i;
-      unsigned long long* p = mask_host + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv[j] |= p[j];
-      }
-    }
-  }
+  at::Tensor keep = at::zeros(
+      {dets_num},
+      dets.options().dtype(at::kBool).device(at::kCUDA)
+  );
+
+  // Unwrap the mask to fill keep with proper values
+  // Keeping this unwrap on cuda instead of applying iterative for loops on cpu
+  // prevents the device -> cpu -> device transfer that could be bottleneck for
+  // large number of boxes.
+  // See https://github.com/pytorch/vision/issues/8713 for more details
+  gather_keep_from_mask<<<1, min(col_blocks, threadsPerBlock),
+                          col_blocks * sizeof(unsigned long long), stream>>>(
+      keep.data_ptr(), (unsigned long long*)mask.data_ptr(),
+      dets_num);
 
   AT_CUDA_CHECK(cudaGetLastError());
-  return order_t.index(
-      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
-           .to(order_t.device(), keep.scalar_type())});
+  return order_t.masked_select(keep);
 }
 
 } // namespace