Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch_npu v2.1 #2943

Merged
merged 4 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion mmcv/ops/csrc/common/pytorch_npu_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#ifndef PYTORCH_NPU_HELPER_HPP_
#define PYTORCH_NPU_HELPER_HPP_

#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/aten/CustomFunctions.h>
#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>

Expand All @@ -27,9 +27,21 @@

#define NPU_NAME_SPACE at_npu::native

#ifdef MMCV_WITH_XLA
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#else
#define REGISTER_NPU_IMPL(key, value) \
REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
#endif

#ifdef MMCV_WITH_XLA
#define CHECK_NPU(x) \
TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
#else
#define CHECK_NPU(x) \
TORCH_CHECK(x.device().type() == at::kPrivateUse1, #x \
" must be a NPU " \
"tensor")
#endif

#endif // PYTORCH_NPU_HELPER_HPP_
8 changes: 5 additions & 3 deletions mmcv/ops/csrc/pytorch/nms_rotated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
#else
AT_ERROR("Not compiled with GPU support");
#endif
#ifdef MMCV_WITH_XLA
} else if (dets.device().type() == at::kXLA) {
#ifdef MMCV_WITH_NPU
return nms_rotated_npu(dets, scores, labels, iou_threshold);
#else
AT_ERROR("Not compiled with NPU support");
#endif
#ifdef MMCV_WITH_KPRIVATE
} else if (dets.device().type() == at::kPrivateUse1) {
return nms_rotated_npu(dets, scores, labels, iou_threshold);
#endif
#ifdef MMCV_WITH_MLU
} else if (dets.device().type() == at::kMLU) {
Expand Down
12 changes: 6 additions & 6 deletions mmcv/ops/csrc/pytorch/npu/bbox_overlaps_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
bboxesFP32 = bboxes1;
gtboxesFP32 = bboxes2;
}
if (bboxes2.scalar_type() != at::ScalarType::Float) {
bboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxesFP32, at::kFloat);
gtboxesFP32 = NPUNativeFunctions::npu_dtype_cast(gtboxesFP32, at::kFloat);
if (bboxes2.scalar_type() != at::kFloat) {
bboxesFP32 = bboxesFP32.to(at::kFloat);
gtboxesFP32 = gtboxesFP32.to(at::kFloat);
}
c10::SmallVector<int64_t, SIZE> iousSize = {gtboxesFP32.size(0),
bboxesFP32.size(0)};
if (aligned) {
iousSize = {gtboxesFP32.size(0), 1};
}
at::Tensor iousFP32 = OpPreparation::ApplyTensor(bboxesFP32, iousSize);
at::Tensor iousFP32 = at::empty(iousSize, bboxesFP32.options());
bboxesFP32 = aligned ? bboxesFP32.transpose(0, 1) : bboxesFP32;
gtboxesFP32 = aligned ? gtboxesFP32.transpose(0, 1) : gtboxesFP32;
OpCommand cmd;
Expand All @@ -41,8 +41,8 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
.Attr("eps", (float)offset)
.Attr("aligned", aligned)
.Run();
if (bboxes2.scalar_type() != at::ScalarType::Float) {
iousFP32 = NPUNativeFunctions::npu_dtype_cast(iousFP32, at::kHalf);
if (bboxes2.scalar_type() != at::kFloat) {
iousFP32 = iousFP32.to(at::kHalf);
}
iousFP32 = swap_flag ? iousFP32.transpose(0, 1) : iousFP32;
ious.copy_(iousFP32);
Expand Down
37 changes: 13 additions & 24 deletions mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
target_y = at::mul(target_y, -1.0);
target_y = at::add(target_y, 1.0);
} else {
target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y = at::one_hot(target, n_class);
}
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
target_y = target_y.to(at::kInt);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
weight_y = at::broadcast_to(weight, input.sizes());
}
OpCommand cmd;
string reduction = "none";
Expand All @@ -46,18 +44,16 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
if (n_class == 1) {
target_y = at::reshape(target, input.sizes());
} else {
target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y = at::one_hot(target, n_class);
target_y = at::mul(target_y, -1.0);
target_y = at::add(target_y, 1.0);
}
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
target_y = target_y.to(at::kInt);
at::Tensor grad_up = at::ones_like(input);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
weight_y = at::broadcast_to(weight, input.sizes());
}
OpCommand cmd;
string reduction = "none";
Expand All @@ -80,15 +76,12 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
int64_t n_class = input.size(1);
at::Tensor target_y =
at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
at::Tensor target_y = at::one_hot(target, n_class);
target_y = target_y.to(at::kInt);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
weight_y = at::broadcast_to(weight, input.sizes());
}
at::Tensor op_output = at::ones_like(input);
OpCommand cmd;
Expand All @@ -107,8 +100,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
at::IntArrayRef offset = at::IntArrayRef(offsets);
at::IntArrayRef size = at::IntArrayRef(sizes);
at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size,
output);
at_npu::native::custom_ops::npu_slice_out(op_output, offset, size, output);
}

void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Expand All @@ -119,16 +111,13 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input,
float gamma, float alpha) {
int64_t n_class = input.size(1);
at::Tensor target_y =
at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
at::Tensor target_y = at::one_hot(target, n_class);
target_y = target_y.to(at::kInt);
at::Tensor grad_up = at::ones_like(input);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
weight_y = at::broadcast_to(weight, input.sizes());
}
OpCommand cmd;
string reduction = "none";
Expand Down
3 changes: 1 addition & 2 deletions mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias,
}
}
at::Tensor bias_tmp = at::reshape(bias, input_size_tmp);
at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(
bias_tmp, input.sizes());
at::Tensor bias_ = at::broadcast_to(bias_tmp, input.sizes());
OpCommand cmd;
cmd.Name("FusedBiasLeakyRelu")
.Input(input)
Expand Down
26 changes: 10 additions & 16 deletions mmcv/ops/csrc/pytorch/npu/nms_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,15 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
TORCH_CHECK((boxes.scalar_type() == at::ScalarType::Float),
"The type of boxes tensor passed in nms_npu should be float");
int64_t offset_64 = offset;
at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor(
{}, boxes.options().dtype(at::kFloat), boxes)
.fill_(iou_threshold);
at::Tensor iou_threshold_y =
at::empty({}, boxes.options().dtype(at::kFloat)).fill_(iou_threshold);
at::Tensor scores_threshold_y =
at_npu::native::OpPreparation::ApplyTensor(
{}, boxes.options().dtype(at::kFloat), boxes)
.fill_(0);
at::Tensor max_outputsize_y = at_npu::native::OpPreparation::ApplyTensor(
{}, boxes.options().dtype(at::kInt), boxes)
.fill_(boxes.size(0));
at::empty({}, boxes.options().dtype(at::kFloat)).fill_(0);
at::Tensor max_outputsize_y =
at::empty({}, boxes.options().dtype(at::kInt)).fill_(boxes.size(0));
c10::SmallVector<int64_t, SIZE> outputsize = {boxes.size(0)};
at::Tensor output = at_npu::native::OpPreparation::ApplyTensor(
outputsize, boxes.options().dtype(at::kInt), boxes)
.fill_(-1);
at::Tensor output =
at::empty(outputsize, boxes.options().dtype(at::kInt)).fill_(-1);
OpCommand cmd;
cmd.Name("NonMaxSuppressionV3")
.Input(boxes)
Expand All @@ -32,11 +27,10 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
.Output(output)
.Run();
auto outputsizeBool = at::gt(output, -1);
auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int);
auto countLen = at::sum(outputsizeInt, at::ScalarType::Int);
auto outputsizeInt = outputsizeBool.to(at::kInt);
auto countLen = at::sum(outputsizeInt, at::kInt);
at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong());
actual_output = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
actual_output, at::kLong);
actual_output = actual_output.to(at::kLong);
return actual_output;
}

Expand Down
14 changes: 7 additions & 7 deletions mmcv/ops/csrc/pytorch/npu/nms_rotated_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
auto originDtype = dets.scalar_type();
at::Tensor detsCast = dets;
at::Tensor scoresCast = scores;
if (originDtype != at::ScalarType::Float) {
detsCast = NPUNativeFunctions::npu_dtype_cast(dets, at::kFloat);
scoresCast = NPUNativeFunctions::npu_dtype_cast(scores, at::kFloat);
if (originDtype != at::kFloat) {
detsCast = detsCast.to(at::kFloat);
scoresCast = scoresCast.to(at::kFloat);
}
c10::SmallVector<int64_t, SIZE> selectedIndexSize = {dets.size(0)};
at::Tensor selectedBox = OpPreparation::ApplyTensor(dets);
at::Tensor selectedIndex = OpPreparation::ApplyTensor(
selectedIndexSize, dets.options().dtype(at::kInt), dets);
at::Tensor selectedBox = at::empty_like(dets);
at::Tensor selectedIndex =
at::empty(selectedIndexSize, dets.options().dtype(at::kInt));

c10::SmallVector<int64_t, N> output_sync_idx = {0, 1};
OpCommand cmd;
Expand All @@ -27,6 +27,6 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores,
.Output(selectedIndex)
.Attr("iou_threshold", (float)iou_threshold)
.Run();
selectedIndex = NPUNativeFunctions::npu_dtype_cast(selectedIndex, at::kLong);
selectedIndex = selectedIndex.to(at::kLong);
return selectedIndex;
}
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y,
int64_t sampling_ratio_64 = sampling_ratio;
int64_t roi_end_mode = 0;
c10::SmallVector<int64_t, SIZE> xdiff_shape =
at_npu::native::array_to_small_vector(grad_input.sizes());
array_to_small_vector(grad_input.sizes());
OpCommand cmd;
cmd.Name("ROIAlignGrad")
.Input(grad_output)
Expand Down
8 changes: 4 additions & 4 deletions mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output,
int64_t pooled_height_64 = pooled_height;
int64_t pooled_width_64 = pooled_width;
int64_t pooled_channel = 1;
at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
{}, rois.options().dtype(at::kInt), rois);
at::Tensor roi_actual_num =
at::empty_like(rois, rois.options().dtype(at::kInt));
OpCommand cmd;
cmd.Name("RoiPoolingWithArgMax")
.Input(input)
Expand All @@ -32,8 +32,8 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax,
int64_t pooled_height_64 = pooled_height;
int64_t pooled_width_64 = pooled_width;
int64_t pooled_channel = 1;
at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
{}, rois.options().dtype(at::kInt), rois);
at::Tensor roi_actual_num =
at::empty_like(rois, rois.options().dtype(at::kInt));
at::Tensor x = at::ones_like(grad_input);
OpCommand cmd;
cmd.Name("RoiPoolingGradWithArgMax")
Expand Down
5 changes: 2 additions & 3 deletions mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
at::Tensor voxel_num_tmp = OpPreparation::ApplyTensor(points, {1});
at::Tensor voxel_num = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
voxel_num_tmp, at::kInt);
at::Tensor voxel_num_tmp = at::empty({1}, points.options());
at::Tensor voxel_num = voxel_num_tmp.to(at::kInt);

at::Tensor voxel_size_cpu = at::from_blob(
const_cast<float *>(voxel_size.data()), {3}, dtype(at::kFloat));
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import platform
import re
import warnings
from pkg_resources import DistributionNotFound, get_distribution
from pkg_resources import DistributionNotFound, get_distribution, parse_version
from setuptools import find_packages, setup

EXT_TYPE = ''
Expand Down Expand Up @@ -428,6 +428,10 @@ def get_mluops_version(file_path):
from torch_npu.utils.cpp_extension import NpuExtension
define_macros += [('MMCV_WITH_NPU', None)]
extension = NpuExtension
if parse_version(torch.__version__) <= parse_version('2.0.0'):
define_macros += [('MMCV_WITH_XLA', None)]
if parse_version(torch.__version__) > parse_version('2.0.0'):
define_macros += [('MMCV_WITH_KPRIVATE', None)]
except Exception:
raise ImportError('can not find any torch_npu')
# src
Expand Down
Loading