diff --git a/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp b/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp index 59d0f1ef77..fbb979ff02 100644 --- a/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp @@ -12,11 +12,17 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, if (mode == 1) { modeStr = "iof"; } + bool swap_flag = false; at::Tensor bboxesFP32 = bboxes2; at::Tensor gtboxesFP32 = bboxes1; + if (bboxes2.size(0) < bboxes1.size(0)) { + swap_flag = true; + bboxesFP32 = bboxes1; + gtboxesFP32 = bboxes2; + } if (bboxes2.scalar_type() != at::ScalarType::Float) { - bboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxes2, at::kFloat); - gtboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxes1, at::kFloat); + bboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxesFP32, at::kFloat); + gtboxesFP32 = NPUNativeFunctions::npu_dtype_cast(gtboxesFP32, at::kFloat); } c10::SmallVector iousSize = {gtboxesFP32.size(0), bboxesFP32.size(0)}; @@ -38,6 +44,7 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, if (bboxes2.scalar_type() != at::ScalarType::Float) { iousFP32 = NPUNativeFunctions::npu_dtype_cast(iousFP32, at::kHalf); } + iousFP32 = swap_flag ? iousFP32.transpose(0, 1) : iousFP32; ious.copy_(iousFP32); }