From 59b8a468ef8905b7728f98acee3dba03290f6d21 Mon Sep 17 00:00:00 2001
From: momo609 <963372609@qq.com>
Date: Mon, 25 Sep 2023 09:19:12 +0800
Subject: [PATCH 1/4] mmcv 1.x adapter for torch_npu

---
 mmcv/ops/csrc/common/pytorch_npu_helper.hpp   | 14 +++++++-
 mmcv/ops/csrc/pytorch/nms_rotated.cpp         |  7 ++--
 .../csrc/pytorch/npu/bbox_overlaps_npu.cpp    | 12 +++---
 mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp  | 34 ++++++-----------
 .../pytorch/npu/fused_bias_leakyrelu_npu.cpp  |  3 +-
 mmcv/ops/csrc/pytorch/npu/nms_npu.cpp         | 26 ++++++--------
 mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp | 14 ++++----
 mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp   |  2 +-
 mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp    |  8 ++---
 .../ops/csrc/pytorch/npu/voxelization_npu.cpp |  3 +-
 setup.py                                      |  6 +++-
 11 files changed, 64 insertions(+), 65 deletions(-)

diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp
index 88607d23b3..fe0cdce545 100644
--- a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp
+++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp
@@ -18,7 +18,7 @@
 #ifndef PYTORCH_NPU_HELPER_HPP_
 #define PYTORCH_NPU_HELPER_HPP_

-#include 
+#include 

 #include 
 #include 
@@ -27,9 +27,21 @@

 #define NPU_NAME_SPACE at_npu::native

+#ifdef MMCV_WITH_XLA
 #define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
+#else
+#define REGISTER_NPU_IMPL(key, value) \
+  REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
+#endif

+#ifdef MMCV_WITH_XLA
 #define CHECK_NPU(x) \
   TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
+#else
+#define CHECK_NPU(x)                                    \
+  TORCH_CHECK(x.device().type() == at::kPrivateUse1, #x \
+              " must be a NPU "                         \
+              "tensor")
+#endif

 #endif  // PYTORCH_NPU_HELPER_HPP_
diff --git a/mmcv/ops/csrc/pytorch/nms_rotated.cpp b/mmcv/ops/csrc/pytorch/nms_rotated.cpp
index 1d49c37dd6..b7f485fd15 100644
--- a/mmcv/ops/csrc/pytorch/nms_rotated.cpp
+++ b/mmcv/ops/csrc/pytorch/nms_rotated.cpp
@@ -36,11 +36,12 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
 #else
   AT_ERROR("Not compiled with GPU support");
 #endif
+#ifdef MMCV_WITH_XLA
   } else if (dets.device().type() == at::kXLA) {
-#ifdef MMCV_WITH_NPU
+#endif
+#ifdef MMCV_WITH_KPRIVATE
+  } else if (dets.device().type() == at::kPrivateUse1) {
     return nms_rotated_npu(dets, scores, labels, iou_threshold);
-#else
-    AT_ERROR("Not compiled with NPU support");
 #endif
 #ifdef MMCV_WITH_MLU
   } else if (dets.device().type() == at::kMLU) {
diff --git a/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp b/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp
index fbb979ff02..ed04622af6 100644
--- a/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp
@@ -20,16 +20,16 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
     bboxesFP32 = bboxes1;
     gtboxesFP32 = bboxes2;
   }
-  if (bboxes2.scalar_type() != at::ScalarType::Float) {
-    bboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxesFP32, at::kFloat);
-    gtboxesFP32 = NPUNativeFunctions::npu_dtype_cast(gtboxesFP32, at::kFloat);
+  if (bboxes2.scalar_type() != at::kFloat) {
+    bboxesFP32 = bboxesFP32.to(at::kFloat);
+    gtboxesFP32 = gtboxesFP32.to(at::kFloat);
   }
   c10::SmallVector iousSize = {gtboxesFP32.size(0),
                                bboxesFP32.size(0)};
   if (aligned) {
     iousSize = {gtboxesFP32.size(0), 1};
   }
-  at::Tensor iousFP32 = OpPreparation::ApplyTensor(bboxesFP32, iousSize);
+  at::Tensor iousFP32 = at::empty(iousSize, bboxesFP32.options());
   bboxesFP32 = aligned ? bboxesFP32.transpose(0, 1) : bboxesFP32;
   gtboxesFP32 = aligned ? gtboxesFP32.transpose(0, 1) : gtboxesFP32;
   OpCommand cmd;
@@ -41,8 +41,8 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
       .Attr("eps", (float)offset)
       .Attr("aligned", aligned)
       .Run();
-  if (bboxes2.scalar_type() != at::ScalarType::Float) {
-    iousFP32 = NPUNativeFunctions::npu_dtype_cast(iousFP32, at::kHalf);
+  if (bboxes2.scalar_type() != at::kFloat) {
+    iousFP32 = iousFP32.to(at::kHalf);
   }
   iousFP32 = swap_flag ? iousFP32.transpose(0, 1) : iousFP32;
   ious.copy_(iousFP32);
diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
index c949bf9539..a6e08f1067 100644
--- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
@@ -12,15 +12,13 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
     target_y = at::mul(target_y, -1.0);
     target_y = at::add(target_y, 1.0);
   } else {
-    target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
+    target_y = at::one_hot(target, n_class);
   }
-  target_y =
-      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
+  target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
   if (weight_size > 0) {
-    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
-                                                                 input.sizes());
+    weight_y = at::broadcast_to(weight, input.sizes());
   }
   OpCommand cmd;
   string reduction = "none";
@@ -46,12 +44,11 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
   if (n_class == 1) {
     target_y = at::reshape(target, input.sizes());
   } else {
-    target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
+    target_y = at::one_hot(target, n_class);
     target_y = at::mul(target_y, -1.0);
     target_y = at::add(target_y, 1.0);
   }
-  target_y =
-      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
+  target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
@@ -80,15 +77,12 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
 void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
   int64_t n_class = input.size(1);
-  at::Tensor target_y =
-      at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
-  target_y =
-      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
+  at::Tensor target_y = at::one_hot(target, n_class);
+  target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
   if (weight_size > 0) {
-    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
-                                                                 input.sizes());
+    weight_y = at::broadcast_to(weight, input.sizes());
   }
   at::Tensor op_output = at::ones_like(input);
   OpCommand cmd;
@@ -107,8 +101,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
   c10::SmallVector sizes = {n_batch, 1};
   at::IntArrayRef offset = at::IntArrayRef(offsets);
   at::IntArrayRef size = at::IntArrayRef(sizes);
-  at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size,
-                                                    output);
+  at_npu::native::custom_ops::npu_slice_out(op_output, offset, size, output);
 }

 void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -119,16 +112,13 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target,
                                      Tensor weight, Tensor buff, Tensor grad_input,
                                      float gamma, float alpha) {
   int64_t n_class = input.size(1);
-  at::Tensor target_y =
-      at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
-  target_y =
-      at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
+  at::Tensor target_y = at::one_hot(target, n_class);
+  target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
   if (weight_size > 0) {
-    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
-                                                                 input.sizes());
+    weight_y = at::broadcast_to(weight, input.sizes());
   }
   OpCommand cmd;
   string reduction = "none";
diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp
index cd052b5868..4fc168094e 100644
--- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp
@@ -25,8 +25,7 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias,
     }
   }
   at::Tensor bias_tmp = at::reshape(bias, input_size_tmp);
-  at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(
-      bias_tmp, input.sizes());
+  at::Tensor bias_ = at::broadcast_to(bias_tmp, input.sizes());
   OpCommand cmd;
   cmd.Name("FusedBiasLeakyRelu")
       .Input(input)
diff --git a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp
index 0a1f997a27..2d9ee8632e 100644
--- a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp
@@ -7,20 +7,15 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
   TORCH_CHECK((boxes.scalar_type() == at::ScalarType::Float),
               "The type of boxes tensor passed in nms_npu should be float");
   int64_t offset_64 = offset;
-  at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor(
-                                   {}, boxes.options().dtype(at::kFloat), boxes)
-                                   .fill_(iou_threshold);
+  at::Tensor iou_threshold_y =
+      at::empty({}, boxes.options().dtype(at::kFloat)).fill_(iou_threshold);
   at::Tensor scores_threshold_y =
-      at_npu::native::OpPreparation::ApplyTensor(
-          {}, boxes.options().dtype(at::kFloat), boxes)
-          .fill_(0);
-  at::Tensor max_outputsize_y = at_npu::native::OpPreparation::ApplyTensor(
-                                    {}, boxes.options().dtype(at::kInt), boxes)
-                                    .fill_(boxes.size(0));
+      at::empty({}, boxes.options().dtype(at::kFloat)).fill_(0);
+  at::Tensor max_outputsize_y =
+      at::empty({}, boxes.options().dtype(at::kInt)).fill_(boxes.size(0));
   c10::SmallVector outputsize = {boxes.size(0)};
-  at::Tensor output = at_npu::native::OpPreparation::ApplyTensor(
-                          outputsize, boxes.options().dtype(at::kInt), boxes)
-                          .fill_(-1);
+  at::Tensor output =
+      at::empty(outputsize, boxes.options().dtype(at::kInt)).fill_(-1);
   OpCommand cmd;
   cmd.Name("NonMaxSuppressionV3")
       .Input(boxes)
@@ -32,11 +27,10 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
       .Attr("offset", offset_64)
       .Output(output)
       .Run();
   auto outputsizeBool = at::gt(output, -1);
-  auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int);
-  auto countLen = at::sum(outputsizeInt, at::ScalarType::Int);
+  auto outputsizeInt = outputsizeBool.to(at::kInt);
+  auto countLen = at::sum(outputsizeInt, at::kInt);
   at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong());
-  actual_output = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
-      actual_output, at::kLong);
+  actual_output = actual_output.to(at::kLong);
   return actual_output;
 }
diff --git a/mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp
index b82ae585cd..b7bdd90b23 100644
--- a/mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp
@@ -7,14 +7,14 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
   auto originDtype = dets.scalar_type();
   at::Tensor detsCast = dets;
   at::Tensor scoresCast = scores;
-  if (originDtype != at::ScalarType::Float) {
-    detsCast = NPUNativeFunctions::npu_dtype_cast(dets, at::kFloat);
-    scoresCast = NPUNativeFunctions::npu_dtype_cast(scores, at::kFloat);
+  if (originDtype != at::kFloat) {
+    detsCast = detsCast.to(at::kFloat);
+    scoresCast = scoresCast.to(at::kFloat);
   }
   c10::SmallVector selectedIndexSize = {dets.size(0)};
-  at::Tensor selectedBox = OpPreparation::ApplyTensor(dets);
-  at::Tensor selectedIndex = OpPreparation::ApplyTensor(
-      selectedIndexSize, dets.options().dtype(at::kInt), dets);
+  at::Tensor selectedBox = at::empty_like(dets);
+  at::Tensor selectedIndex =
+      at::empty(selectedIndexSize, dets.options().dtype(at::kInt));

   c10::SmallVector output_sync_idx = {0, 1};
   OpCommand cmd;
@@ -27,6 +27,6 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
       .Output(selectedIndex)
       .Attr("iou_threshold", (float)iou_threshold)
       .Run();
-  selectedIndex = NPUNativeFunctions::npu_dtype_cast(selectedIndex, at::kLong);
+  selectedIndex = selectedIndex.to(at::kLong);
   return selectedIndex;
 }
diff --git a/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp
index 31335557e4..465471976e 100644
--- a/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp
@@ -37,7 +37,7 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y,
   int64_t sampling_ratio_64 = sampling_ratio;
   int64_t roi_end_mode = 0;
   c10::SmallVector xdiff_shape =
-      at_npu::native::array_to_small_vector(grad_input.sizes());
+      array_to_small_vector(grad_input.sizes());
   OpCommand cmd;
   cmd.Name("ROIAlignGrad")
       .Input(grad_output)
diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp
index f428311fee..bf0eb18d2b 100644
--- a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp
@@ -9,8 +9,8 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output,
   int64_t pooled_height_64 = pooled_height;
   int64_t pooled_width_64 = pooled_width;
   int64_t pooled_channel = 1;
-  at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
-      {}, rois.options().dtype(at::kInt), rois);
+  at::Tensor roi_actual_num =
+      at::empty_like(rois, rois.options().dtype(at::kInt));
   OpCommand cmd;
   cmd.Name("RoiPoolingWithArgMax")
       .Input(input)
@@ -32,8 +32,8 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax,
   int64_t pooled_height_64 = pooled_height;
   int64_t pooled_width_64 = pooled_width;
   int64_t pooled_channel = 1;
-  at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
-      {}, rois.options().dtype(at::kInt), rois);
+  at::Tensor roi_actual_num =
+      at::empty_like(rois, rois.options().dtype(at::kInt));
   at::Tensor x = at::ones_like(grad_input);
   OpCommand cmd;
   cmd.Name("RoiPoolingGradWithArgMax")
diff --git a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
index 2b22646b9e..32b1a50cc1 100644
--- a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
@@ -19,8 +19,7 @@ int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels,
                               const int max_points, const int max_voxels,
                               const int NDim = 3) {
   at::Tensor voxel_num_tmp = OpPreparation::ApplyTensor(points, {1});
-  at::Tensor voxel_num = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
-      voxel_num_tmp, at::kInt);
+  at::Tensor voxel_num = voxel_num_tmp.to(at::kInt);
   at::Tensor voxel_size_cpu = at::from_blob(
       const_cast(voxel_size.data()), {3}, dtype(at::kFloat));
diff --git a/setup.py b/setup.py
index 8d87eaf0ec..9393d04264 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 import platform
 import re
 import warnings
-from pkg_resources import DistributionNotFound, get_distribution
+from pkg_resources import DistributionNotFound, get_distribution, parse_version
 from setuptools import find_packages, setup

 EXT_TYPE = ''
@@ -428,6 +428,10 @@ def get_mluops_version(file_path):
         from torch_npu.utils.cpp_extension import NpuExtension
         define_macros += [('MMCV_WITH_NPU', None)]
         extension = NpuExtension
+        if parse_version(torch.__version__) <= parse_version('2.0.0'):
+            define_macros += [('MMCV_WITH_XLA', None)]
+        if parse_version(torch.__version__) > parse_version('2.0.0'):
+            define_macros += [('MMCV_WITH_KPRIVATE', None)]
     except Exception:
         raise ImportError('can not find any torch_npu')
     # src

From 0f9e6bea79cea80a6c2f2e8a38aa0a69514e0db4 Mon Sep 17 00:00:00 2001
From: momo609 <963372609@qq.com>
Date: Tue, 26 Sep 2023 09:20:52 +0800
Subject: [PATCH 2/4] fix nms_rotated

---
 mmcv/ops/csrc/pytorch/nms_rotated.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mmcv/ops/csrc/pytorch/nms_rotated.cpp b/mmcv/ops/csrc/pytorch/nms_rotated.cpp
index b7f485fd15..3b23b19309 100644
--- a/mmcv/ops/csrc/pytorch/nms_rotated.cpp
+++ b/mmcv/ops/csrc/pytorch/nms_rotated.cpp
@@ -38,6 +38,7 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
 #endif
 #ifdef MMCV_WITH_XLA
   } else if (dets.device().type() == at::kXLA) {
+    return nms_rotated_npu(dets, scores, labels, iou_threshold);
 #endif
 #ifdef MMCV_WITH_KPRIVATE
   } else if (dets.device().type() == at::kPrivateUse1) {

From fa19ff20be6f6dbb08daba6b6686ec6ad2420c89 Mon Sep 17 00:00:00 2001
From: momo609 <963372609@qq.com>
Date: Tue, 26 Sep 2023 10:24:15 +0800
Subject: [PATCH 3/4] fix nms_rotated

---
 mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
index a6e08f1067..b7c995a223 100644
--- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
@@ -53,8 +53,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
   int64_t weight_size = weight.size(0);
   at::Tensor weight_y = at::ones_like(input);
   if (weight_size > 0) {
-    weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
-                                                                 input.sizes());
+    weight_y = at::broadcast_to(weight, input.sizes());
   }
   OpCommand cmd;
   string reduction = "none";

From 0c4f8f4790f1bbb0fd444a5fc1ff5238e0060525 Mon Sep 17 00:00:00 2001
From: momo609 <963372609@qq.com>
Date: Tue, 26 Sep 2023 16:25:23 +0800
Subject: [PATCH 4/4] fix nms_rotated

---
 mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
index 32b1a50cc1..ffd9b4c43b 100644
--- a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
@@ -18,7 +18,7 @@ int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels,
                               const std::vector coors_range,
                               const int max_points, const int max_voxels,
                               const int NDim = 3) {
-  at::Tensor voxel_num_tmp = OpPreparation::ApplyTensor(points, {1});
+  at::Tensor voxel_num_tmp = at::empty({1}, points.options());
   at::Tensor voxel_num = voxel_num_tmp.to(at::kInt);
   at::Tensor voxel_size_cpu = at::from_blob(
       const_cast(voxel_size.data()), {3}, dtype(at::kFloat));
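
A note on the mechanism (not part of the patches themselves): the version gate that patch 1 adds to setup.py decides which dispatch key mmcv compiles against, because torch_npu exposes NPU tensors through the XLA device type on torch <= 2.0.0 and through the PrivateUse1 device type on later releases. The check reduces to the following minimal standalone sketch, assuming only that torch and pkg_resources are importable (npu_dispatch_macro is a hypothetical helper name, not an mmcv API):

    # Sketch only: mirrors the macro selection in the setup.py hunk above.
    import torch
    from pkg_resources import parse_version

    def npu_dispatch_macro() -> str:
        # torch <= 2.0.0: torch_npu registers NPU tensors under the XLA key.
        if parse_version(torch.__version__) <= parse_version('2.0.0'):
            return 'MMCV_WITH_XLA'
        # Newer torch: NPU tensors report the PrivateUse1 device type.
        return 'MMCV_WITH_KPRIVATE'

    print(npu_dispatch_macro())

Exactly one of the two macros is defined per build, so the REGISTER_NPU_IMPL and CHECK_NPU branches in pytorch_npu_helper.hpp always compile against the single device key that the installed torch_npu actually reports.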