From 0694a21cc7a20b732df8a0774cb1ef24dcc1c206 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Dec 2022 14:18:29 +0800 Subject: [PATCH 1/5] merge npu ops from master to 2.x --- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 35 ++++ mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp | 63 +++++++ mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 162 ++++++++++++++++++ .../pytorch/npu/fused_bias_leakyrelu_npu.cpp | 54 ++++++ mmcv/ops/csrc/pytorch/npu/nms_npu.cpp | 45 +++++ mmcv/ops/masked_conv.py | 18 ++ mmcv/ops/modulated_deform_conv.py | 67 ++++++++ mmcv/utils/__init__.py | 7 +- mmcv/utils/device_type.py | 3 +- setup.py | 15 ++ tests/test_ops/test_deform_roi_pool.py | 6 +- tests/test_ops/test_focal_loss.py | 10 +- tests/test_ops/test_fused_bias_leakyrelu.py | 57 ++++-- 13 files changed, 520 insertions(+), 22 deletions(-) create mode 100644 mmcv/ops/csrc/common/pytorch_npu_helper.hpp create mode 100644 mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/nms_npu.cpp diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp new file mode 100644 index 0000000000..88607d23b3 --- /dev/null +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -0,0 +1,35 @@ +/****************************************************************************** + * Copyright (c) 2022 Huawei Technologies Co., Ltd + * All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ + +#ifndef PYTORCH_NPU_HELPER_HPP_ +#define PYTORCH_NPU_HELPER_HPP_ + +#include +#include +#include + +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +#define NPU_NAME_SPACE at_npu::native + +#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value) + +#define CHECK_NPU(x) \ + TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") + +#endif // PYTORCH_NPU_HELPER_HPP_ diff --git a/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp b/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp new file mode 100644 index 0000000000..0e9f2ee7ac --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp @@ -0,0 +1,63 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void deform_roi_pool_forward_impl(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma); + +void deform_roi_pool_backward_impl(Tensor grad_output, Tensor input, + Tensor rois, Tensor offset, + Tensor grad_input, Tensor grad_offset, + int pooled_height, int pooled_width, + float spatial_scale, int sampling_ratio, + float gamma); + +void deform_roi_pool_forward_npu(Tensor input, Tensor rois, Tensor offset, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma) { + c10::SmallVector output_sizes = {pooled_height, pooled_width}; + at::IntArrayRef output_size = at::IntArrayRef(output_sizes); + int64_t sampling_ratio_ = (int64_t)sampling_ratio; + OpCommand cmd; + cmd.Name("DeformableRoiPool") + .Input(input) + .Input(rois) + .Input(offset) + .Output(output) + .Attr("spatial_scale", spatial_scale) + .Attr("output_size", output_size) + .Attr("sampling_ratio", sampling_ratio_) + .Attr("gamma", gamma) + .Run(); +} + +void deform_roi_pool_backward_npu(Tensor grad_output, Tensor input, Tensor rois, + Tensor offset, Tensor grad_input, + Tensor grad_offset, int pooled_height, + int pooled_width, float spatial_scale, + int sampling_ratio, float gamma) { + c10::SmallVector output_sizes = {pooled_height, pooled_width}; + at::IntArrayRef output_size = at::IntArrayRef(output_sizes); + int64_t sampling_ratio_ = (int64_t)sampling_ratio; + OpCommand cmd; + cmd.Name("DeformableRoiPoolGrad") + .Input(grad_input) + .Input(input) + .Input(rois) + .Input(offset) + .Output(grad_output) + .Output(grad_offset) + .Attr("output_size", output_size) + .Attr("spatial_scale", spatial_scale) + .Attr("sample_ratio", sampling_ratio_) + .Attr("gamma", gamma) + .Run(); +} + +REGISTER_NPU_IMPL(deform_roi_pool_forward_impl, deform_roi_pool_forward_npu); + +REGISTER_NPU_IMPL(deform_roi_pool_backward_impl, deform_roi_pool_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp new file mode 100644 index 0000000000..c949bf9539 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -0,0 +1,162 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha) { + int64_t n_class = input.size(1); + at::Tensor target_y = at::ones_like(input); + if (n_class == 1) { + target_y = at::reshape(target, input.sizes()); + target_y = at::mul(target_y, -1.0); + target_y = at::add(target_y, 1.0); + } else { + target_y = 
at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + } + target_y = + at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if (weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, + input.sizes()); + } + OpCommand cmd; + string reduction = "none"; + cmd.Name("SigmoidFocalLoss") + .Input(input) + .Input(target_y) + .Input(weight_y) + .Output(output) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", reduction) + .Run(); +} + +void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha); + +void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, + float alpha) { + int64_t n_class = input.size(1); + at::Tensor target_y = at::ones_like(input); + if (n_class == 1) { + target_y = at::reshape(target, input.sizes()); + } else { + target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = at::mul(target_y, -1.0); + target_y = at::add(target_y, 1.0); + } + target_y = + at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if (weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, + input.sizes()); + } + OpCommand cmd; + string reduction = "none"; + cmd.Name("SigmoidFocalLossGrad") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", reduction) + .Run(); +} + +void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, + Tensor weight, Tensor grad_input, + float gamma, float alpha); + +void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha) { + int64_t n_class = input.size(1); + at::Tensor target_y = + at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = + at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if (weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, + input.sizes()); + } + at::Tensor op_output = at::ones_like(input); + OpCommand cmd; + string reduction = "none"; + cmd.Name("SoftmaxFocalLoss") + .Input(input) + .Input(target_y) + .Input(weight_y) + .Output(op_output) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", reduction) + .Run(); + int64_t n_batch = input.size(0); + c10::SmallVector offsets = {0, 0}; + c10::SmallVector sizes = {n_batch, 1}; + at::IntArrayRef offset = at::IntArrayRef(offsets); + at::IntArrayRef size = at::IntArrayRef(sizes); + at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size, + output); +} + +void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, + float alpha); + +void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, + Tensor buff, Tensor grad_input, + float gamma, float alpha) { + int64_t n_class = input.size(1); + at::Tensor target_y = + at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = + at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + 
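+  // Note: the `buff` tensor from the generic interface is not consumed by
+  // this NPU implementation; the one-hot, int32 `target_y` built above is
+  // the only label input the kernel needs.
+  // `grad_up` below is an all-ones stand-in for the upstream gradient; the
+  // Python autograd wrapper is assumed to rescale grad_input by the real
+  // grad_output afterwards.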
at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if (weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, + input.sizes()); + } + OpCommand cmd; + string reduction = "none"; + cmd.Name("SoftmaxFocalLossGrad") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", reduction) + .Run(); +} + +void softmax_focal_loss_backward_impl(Tensor input, Tensor target, + Tensor weight, Tensor buff, + Tensor grad_input, float gamma, + float alpha); + +REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl, + sigmoid_focal_loss_forward_npu); + +REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl, + sigmoid_focal_loss_backward_npu); + +REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl, + softmax_focal_loss_forward_npu); + +REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl, + softmax_focal_loss_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp new file mode 100644 index 0000000000..cd052b5868 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -0,0 +1,54 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +Tensor fused_bias_leakyrelu_op_impl(const Tensor &input, const Tensor &bias, + const Tensor &refer, int act, int grad, + float alpha, float scale); + +Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, + const Tensor &refer, int act, int grad, + float alpha, float scale) { + at::Tensor py = at::empty_like(input); + // forward + if (grad == 0) { + auto input_size = input.sizes(); + int input_length = input_size.size(); + c10::SmallVector input_size_tmp; + input_size_tmp = array_to_small_vector(input_size); + if (input_length > 1) { + for (int i = 0; i < input_length; i++) { + if (i != 1) { + input_size_tmp[i] = 1; + } + } + } + at::Tensor bias_tmp = at::reshape(bias, input_size_tmp); + at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast( + bias_tmp, input.sizes()); + OpCommand cmd; + cmd.Name("FusedBiasLeakyRelu") + .Input(input) + .Input(bias_) + .Output(py) + .Attr("scale", scale) + .Attr("negative_slope", alpha) + .Run(); + } + + // backward + if (grad == 1) { + OpCommand cmd; + cmd.Name("FusedBiasLeakyReluGrad") + .Input(input) + .Input(refer) + .Output(py) + .Attr("scale", scale) + .Attr("negative_slope", alpha) + .Run(); + } + return py; +} + +REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, fused_bias_leakyrelu_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp new file mode 100644 index 0000000000..2f86893ea7 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp @@ -0,0 +1,45 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) { + at::Tensor boxed_offest = at_npu::native::OpPreparation::ApplyTensor(boxes); + at::Tensor ones_tensor = + at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1); + at::add_out(boxed_offest, boxes, ones_tensor, offset); + at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kFloat), boxes) + .fill_(iou_threshold); + at::Tensor scores_threshold_y = + at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kFloat), 
boxes) + .fill_(0); + at::Tensor max_outputsize_y = at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kInt), boxes) + .fill_(boxes.size(0)); + c10::SmallVector outputsize = {boxes.size(0)}; + at::Tensor output = at_npu::native::OpPreparation::ApplyTensor( + outputsize, boxes.options().dtype(at::kInt), boxes) + .fill_(-1); + OpCommand cmd; + cmd.Name("NonMaxSuppressionV3") + .Input(boxes) + .Input(scores) + .Input(max_outputsize_y) + .Input(iou_threshold_y) + .Input(scores_threshold_y) + .Output(output) + .Run(); + auto outputsizeBool = at::gt(output, -1); + auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int); + auto countLen = at::sum(outputsizeInt, at::ScalarType::Int); + at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong()); + actual_output = at_npu::native::NPUNativeFunctions::npu_dtype_cast( + actual_output, at::kLong); + return actual_output; +} + +Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset); + +REGISTER_NPU_IMPL(nms_impl, nms_npu); diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index e125c735ed..919702e9cb 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -45,6 +45,24 @@ def forward(ctx, 'Stride could not only be 1 in masked_conv2d currently.') out_channel, in_channel, kernel_h, kernel_w = weight.size() + if features.device.type == 'npu': + import torch_npu + output = torch_npu.npu_conv2d( + features, + weight, + bias, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(1, 1), + groups=1) + if mask.size()[1:] != output.size()[2:]: + raise ValueError( + 'The mask is inconsistent with the shape of output_conv.') + mask = mask > 0 + mask = mask.type(output.dtype) + output = output * mask + return output + batch_size = features.size(0) out_h = int( math.floor( diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 81def5c48f..5ff1bcc51c 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -35,6 +35,66 @@ def symbolic(g, input, offset, mask, weight, bias, stride, padding, groups_i=groups, deform_groups_i=deform_groups) + @staticmethod + def _calculate_sort_index(kernel_h, kernel_w, deformable_group): + split_num = deformable_group * 2 * kernel_h * kernel_w + sort_index = list(range(split_num)) + sort_index_fp = (sort_index[1::2] + sort_index[::2]) + sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)} + sort_index_bp = [sort_index_bp_dict[i] for i in sort_index] + sort_index_fp = torch.IntTensor(sort_index_fp) + sort_index_bp = torch.IntTensor(sort_index_bp) + sort_index_fp = sort_index_fp.npu() + sort_index_bp = sort_index_bp.npu() + return sort_index_fp, sort_index_bp + + @staticmethod + def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): + _, _, kernel_h, kernel_w = weight.shape + conv2d_bias = bias if len(bias) > 0 else None + sort_index_fp, sort_index_bp = \ + ModulatedDeformConv2dFunction._calculate_sort_index( + kernel_w, kernel_h, ctx.deform_groups) + select_offset = offset.index_select(1, sort_index_fp) + offset_all = torch.cat([select_offset, mask], dim=1) + output, offset_out = torch.npu_deformable_conv2d( + input_tensor, + weight, + offset_all, + conv2d_bias, + kernel_size=[kernel_w, kernel_h], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, + deformable_groups=ctx.deform_groups, + modulated=True) + if weight.requires_grad 
or mask.requires_grad or offset.requires_grad \ + or input_tensor.requires_grad: + ctx.save_for_backward(input_tensor, weight, offset_out, offset_all, + sort_index_bp) + return output + + @staticmethod + def _npu_backward(ctx, grad_output): + input_tensor, weight, offset_out, offset_all, sort_index_bp = \ + ctx.saved_tensors + grad_input, grad_weight, grad_offset_all, grad_bias = \ + torch.npu_deformable_conv2dbk( + input_tensor, grad_output, offset_out, weight, offset_all, + kernel_size=[weight.shape[3], weight.shape[2]], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, deformable_groups=ctx.deform_groups, + modulated=True) + grad_offset = grad_offset_all.index_select(1, sort_index_bp) + grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :] + if not ctx.with_bias: + grad_bias = None + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, None, None, None, None, None, None) + @staticmethod def forward(ctx, input: torch.Tensor, @@ -70,6 +130,10 @@ def forward(ctx, weight = weight.type_as(input) bias = bias.type_as(input) # type: ignore mask = mask.type_as(input) + if ctx.device == 'npu': + output = ModulatedDeformConv2dFunction._npu_forward( + ctx, input, offset, mask, weight, bias) + return output ctx.save_for_backward(input, offset, mask, weight, bias) output = input.new_empty( ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) @@ -99,6 +163,9 @@ def forward(ctx, @staticmethod @once_differentiable def backward(ctx, grad_output: torch.Tensor) -> tuple: + if ctx.device == 'npu': + return ModulatedDeformConv2dFunction._npu_backward( + ctx, grad_output) input, offset, mask, weight, bias = ctx.saved_tensors grad_input = torch.zeros_like(input) grad_offset = torch.zeros_like(offset) diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 242665a611..53ebb94537 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .device_type import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE +from .device_type import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, + IS_MPS_AVAILABLE, IS_NPU_AVAILABLE) from .env import collect_env from .parrots_jit import jit, skip_no_elena __all__ = [ - 'IS_MLU_AVAILABLE', 'IS_MPS_AVAILABLE', 'IS_CUDA_AVAILABLE', 'collect_env', - 'jit', 'skip_no_elena' + 'IS_MLU_AVAILABLE', 'IS_MPS_AVAILABLE', 'IS_CUDA_AVAILABLE', + 'IS_NPU_AVAILABLE', 'collect_env', 'jit', 'skip_no_elena' ] diff --git a/mmcv/utils/device_type.py b/mmcv/utils/device_type.py index 84b185e8e7..0a84371276 100644 --- a/mmcv/utils/device_type.py +++ b/mmcv/utils/device_type.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from mmengine.device import (is_cuda_available, is_mlu_available, - is_mps_available) + is_mps_available, is_npu_available) IS_MLU_AVAILABLE = is_mlu_available() IS_MPS_AVAILABLE = is_mps_available() IS_CUDA_AVAILABLE = is_cuda_available() +IS_NPU_AVAILABLE = is_npu_available() diff --git a/setup.py b/setup.py index b23818fef5..85d011aad6 100644 --- a/setup.py +++ b/setup.py @@ -270,6 +270,21 @@ def get_extensions(): extension = CppExtension include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps')) + elif (os.getenv('FORCE_NPU', '0') == '1'): + print(f'Compiling {ext_name} only with CPU and NPU') + try: + from torch_npu.utils.cpp_extension import NpuExtension + define_macros += [('MMCV_WITH_NPU', None)] + extension = NpuExtension + except Exception: + raise ImportError('can not find any torch_npu') + # src + op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp') + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu')) else: print(f'Compiling {ext_name} only with CPU') op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ diff --git a/tests/test_ops/test_deform_roi_pool.py b/tests/test_ops/test_deform_roi_pool.py index 5c48e6f777..346301fe41 100644 --- a/tests/test_ops/test_deform_roi_pool.py +++ b/tests/test_ops/test_deform_roi_pool.py @@ -5,7 +5,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE _USING_PARROTS = True try: @@ -126,6 +126,10 @@ def _test_deform_roi_pool_allclose(self, device, dtype=torch.float): assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3) @pytest.mark.parametrize('device', [ + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')), pytest.param( 'cuda', marks=pytest.mark.skipif( diff --git a/tests/test_ops/test_focal_loss.py b/tests/test_ops/test_focal_loss.py index 316f58469d..ee7c9861ae 100644 --- a/tests/test_ops/test_focal_loss.py +++ b/tests/test_ops/test_focal_loss.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE _USING_PARROTS = True try: @@ -130,6 +130,10 @@ def test_softmax_half(self): self._test_softmax(dtype=torch.half) @pytest.mark.parametrize('device', [ + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')), pytest.param( 'cuda', marks=pytest.mark.skipif( @@ -143,6 +147,10 @@ def test_sigmoid_float(self, device): self._test_sigmoid(device=device, dtype=torch.float) @pytest.mark.parametrize('device', [ + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')), pytest.param( 'cuda', marks=pytest.mark.skipif( diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 47357860de..e6f6fb9f75 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -2,6 +2,8 @@ import pytest import torch +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE + _USING_PARROTS = True try: from parrots.autograd import gradcheck @@ -14,36 +16,59 @@ class 
TestFusedBiasLeakyReLU: @classmethod def setup_class(cls): - if not torch.cuda.is_available(): + if not IS_CUDA_AVAILABLE and not IS_NPU_AVAILABLE: return - cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda() - cls.bias = torch.zeros(2, requires_grad=True).cuda() + if IS_CUDA_AVAILABLE: + cls.input_tensor = torch.randn((2, 2, 2, 2), + requires_grad=True).cuda() + cls.bias = torch.zeros(2, requires_grad=True).cuda() + elif IS_NPU_AVAILABLE: + cls.input_tensor = torch.randn((2, 2, 2, 2), + requires_grad=True).npu() + cls.bias = torch.zeros(2, requires_grad=True).npu() - @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') - def test_gradient(self): + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) + ]) + def test_gradient(self, device): from mmcv.ops import FusedBiasLeakyReLU if _USING_PARROTS: - gradcheck( - FusedBiasLeakyReLU(2).cuda(), - self.input_tensor, - delta=1e-4, - pt_atol=1e-3) + if IS_CUDA_AVAILABLE: + gradcheck( + FusedBiasLeakyReLU(2).cuda(), + self.input_tensor, + delta=1e-4, + pt_atol=1e-3) else: gradcheck( - FusedBiasLeakyReLU(2).cuda(), + FusedBiasLeakyReLU(2).to(device), self.input_tensor, eps=1e-4, atol=1e-3) - @pytest.mark.skipif( - not torch.cuda.is_available() or _USING_PARROTS, - reason='requires cuda') - def test_gradgradient(self): + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) + ]) + def test_gradgradient(self, device): from mmcv.ops import FusedBiasLeakyReLU gradgradcheck( - FusedBiasLeakyReLU(2).cuda(), + FusedBiasLeakyReLU(2).to(device), self.input_tensor, eps=1e-4, atol=1e-3) From 869e7f961e255fb36409b49231233564759c016d Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Dec 2022 17:09:56 +0800 Subject: [PATCH 2/5] BugFix: fix merge bugs --- mmcv/ops/focal_loss.py | 2 +- mmcv/ops/modulated_deform_conv.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py index 8d5ccce928..dcd65de318 100644 --- a/mmcv/ops/focal_loss.py +++ b/mmcv/ops/focal_loss.py @@ -117,7 +117,7 @@ def forward(ctx, weight: Optional[torch.Tensor] = None, reduction='mean') -> torch.Tensor: - assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor)) + assert target.dtype == torch.long assert input.dim() == 2 assert target.dim() == 1 assert input.size(0) == target.size(0) diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 5ff1bcc51c..dc3c89d58c 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -117,6 +117,7 @@ def forward(ctx, ctx.groups = groups ctx.deform_groups = deform_groups ctx.with_bias = bias is not None + ctx.device = input.device.type if not ctx.with_bias: bias = input.new_empty(0) # fake tensor # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; From aef335c2d800c8cb4b136a016a72854ca16b997b Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Thu, 12 Jan 2023 16:01:39 +0800 Subject: [PATCH 3/5] {[Feature]: add psamask, roipool to 2.x, and fix the SigmoidFocalLoss assert condition --- docs/en/understand_mmcv/ops.md | 116 
++++++++++----------- docs/zh_cn/understand_mmcv/ops.md | 116 ++++++++++----------- mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp | 75 +++++++++++++ mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp | 34 ++++++ mmcv/ops/deform_conv.py | 27 +++++ mmcv/ops/focal_loss.py | 3 +- mmcv/ops/modulated_deform_conv.py | 2 +- tests/test_ops/test_psa_mask.py | 14 ++- tests/test_ops/test_roi_pool.py | 8 +- 9 files changed, 271 insertions(+), 124 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp create mode 100644 mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index e0a9a3648c..e4f3285984 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -2,61 +2,61 @@ We implement common ops used in detection, segmentation, etc. -| Device | CPU | CUDA | MLU | MPS | -| ---------------------------- | --- | ---- | --- | --- | -| ActiveRotatedFilter | √ | √ | | | -| AssignScoreWithK | | √ | | | -| BallQuery | | √ | | | -| BBoxOverlaps | | √ | √ | √ | -| BorderAlign | | √ | | | -| BoxIouRotated | √ | √ | | | -| BoxIouQuadri | √ | √ | | | -| CARAFE | | √ | √ | | -| ChamferDistance | | √ | | | -| CrissCrossAttention | | √ | | | -| ContourExpand | √ | | | | -| ConvexIoU | | √ | | | -| CornerPool | | √ | | | -| Correlation | | √ | | | -| Deformable Convolution v1/v2 | √ | √ | | | -| Deformable RoIPool | | √ | √ | | -| DiffIoURotated | | √ | | | -| DynamicScatter | | √ | | | -| FurthestPointSample | | √ | | | -| FurthestPointSampleWithDist | | √ | | | -| FusedBiasLeakyrelu | | √ | | | -| GatherPoints | | √ | | | -| GroupPoints | | √ | | | -| Iou3d | | √ | √ | | -| KNN | | √ | | | -| MaskedConv | | √ | √ | | -| MergeCells | | √ | | | -| MinAreaPolygon | | √ | | | -| ModulatedDeformConv2d | √ | √ | | | -| MultiScaleDeformableAttn | | √ | √ | | -| NMS | √ | √ | √ | | -| NMSRotated | √ | √ | | | -| NMSQuadri | √ | √ | | | -| PixelGroup | √ | | | | -| PointsInBoxes | √ | √ | | | -| PointsInPolygons | | √ | | | -| PSAMask | √ | √ | √ | | -| RotatedFeatureAlign | √ | √ | | | -| RoIPointPool3d | | √ | √ | | -| RoIPool | | √ | √ | | -| RoIAlignRotated | √ | √ | √ | | -| RiRoIAlignRotated | | √ | | | -| RoIAlign | √ | √ | √ | | -| RoIAwarePool3d | | √ | √ | | -| SAConv2d | | √ | | | -| SigmoidFocalLoss | | √ | √ | | -| SoftmaxFocalLoss | | √ | | | -| SoftNMS | | √ | | | -| Sparse Convolution | | √ | | | -| Synchronized BatchNorm | | √ | | | -| ThreeInterpolate | | √ | | | -| ThreeNN | | √ | √ | | -| TINShift | | √ | √ | | -| UpFirDn2d | | √ | | | -| Voxelization | √ | √ | | | -| PrRoIPool | | √ | | | +| Device | CPU | CUDA | MLU | MPS | Ascend | +| ---------------------------- | --- | ---- | --- | --- | ------ | +| ActiveRotatedFilter | √ | √ | | | | +| AssignScoreWithK | | √ | | | | +| BallQuery | | √ | | | | +| BBoxOverlaps | | √ | √ | √ | | +| BorderAlign | | √ | | | | +| BoxIouRotated | √ | √ | | | | +| BoxIouQuadri | √ | √ | | | | +| CARAFE | | √ | √ | | | +| ChamferDistance | | √ | | | | +| CrissCrossAttention | | √ | | | | +| ContourExpand | √ | | | | | +| ConvexIoU | | √ | | | | +| CornerPool | | √ | | | | +| Correlation | | √ | | | | +| Deformable Convolution v1/v2 | √ | √ | | | √ | +| Deformable RoIPool | | √ | √ | | √ | +| DiffIoURotated | | √ | | | | +| DynamicScatter | | √ | | | | +| FurthestPointSample | | √ | | | | +| FurthestPointSampleWithDist | | √ | | | | +| FusedBiasLeakyrelu | | √ | | | √ | +| GatherPoints | | √ | | | | +| GroupPoints | | √ | | | | +| Iou3d | | √ | √ | | | +| KNN | | √ | | | | +| 
MaskedConv | | √ | √ | | √ | +| MergeCells | | √ | | | | +| MinAreaPolygon | | √ | | | | +| ModulatedDeformConv2d | √ | √ | | | √ | +| MultiScaleDeformableAttn | | √ | √ | | | +| NMS | √ | √ | √ | | √ | +| NMSRotated | √ | √ | | | | +| NMSQuadri | √ | √ | | | | +| PixelGroup | √ | | | | | +| PointsInBoxes | √ | √ | | | | +| PointsInPolygons | | √ | | | | +| PSAMask | √ | √ | √ | | √ | +| RotatedFeatureAlign | √ | √ | | | | +| RoIPointPool3d | | √ | √ | | | +| RoIPool | | √ | √ | | √ | +| RoIAlignRotated | √ | √ | √ | | | +| RiRoIAlignRotated | | √ | | | | +| RoIAlign | √ | √ | √ | | | +| RoIAwarePool3d | | √ | √ | | | +| SAConv2d | | √ | | | | +| SigmoidFocalLoss | | √ | √ | | √ | +| SoftmaxFocalLoss | | √ | | | √ | +| SoftNMS | | √ | | | | +| Sparse Convolution | | √ | | | | +| Synchronized BatchNorm | | √ | | | | +| ThreeInterpolate | | √ | | | | +| ThreeNN | | √ | √ | | | +| TINShift | | √ | √ | | | +| UpFirDn2d | | √ | | | | +| Voxelization | √ | √ | | | | +| PrRoIPool | | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 6b4622146c..3d6eecca3b 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -2,61 +2,61 @@ MMCV 提供了检测、分割等任务中常用的算子 -| Device | CPU | CUDA | MLU | MPS | -| ---------------------------- | --- | ---- | --- | --- | -| ActiveRotatedFilter | √ | √ | | | -| AssignScoreWithK | | √ | | | -| BallQuery | | √ | | | -| BBoxOverlaps | | √ | √ | √ | -| BorderAlign | | √ | | | -| BoxIouRotated | √ | √ | | | -| BoxIouQuadri | √ | √ | | | -| CARAFE | | √ | √ | | -| ChamferDistance | | √ | | | -| CrissCrossAttention | | √ | | | -| ContourExpand | √ | | | | -| ConvexIoU | | √ | | | -| CornerPool | | √ | | | -| Correlation | | √ | | | -| Deformable Convolution v1/v2 | √ | √ | | | -| Deformable RoIPool | | √ | √ | | -| DiffIoURotated | | √ | | | -| DynamicScatter | | √ | | | -| FurthestPointSample | | √ | | | -| FurthestPointSampleWithDist | | √ | | | -| FusedBiasLeakyrelu | | √ | | | -| GatherPoints | | √ | | | -| GroupPoints | | √ | | | -| Iou3d | | √ | √ | | -| KNN | | √ | | | -| MaskedConv | | √ | √ | | -| MergeCells | | √ | | | -| MinAreaPolygon | | √ | | | -| ModulatedDeformConv2d | √ | √ | | | -| MultiScaleDeformableAttn | | √ | √ | | -| NMS | √ | √ | √ | | -| NMSRotated | √ | √ | | | -| NMSQuadri | √ | √ | | | -| PixelGroup | √ | | | | -| PointsInBoxes | √ | √ | | | -| PointsInPolygons | | √ | | | -| PSAMask | √ | √ | √ | | -| RotatedFeatureAlign | √ | √ | | | -| RoIPointPool3d | | √ | √ | | -| RoIPool | | √ | √ | | -| RoIAlignRotated | √ | √ | √ | | -| RiRoIAlignRotated | | √ | | | -| RoIAlign | √ | √ | √ | | -| RoIAwarePool3d | | √ | √ | | -| SAConv2d | | √ | | | -| SigmoidFocalLoss | | √ | √ | | -| SoftmaxFocalLoss | | √ | | | -| SoftNMS | | √ | | | -| Sparse Convolution | | √ | | | -| Synchronized BatchNorm | | √ | | | -| ThreeInterpolate | | √ | | | -| ThreeNN | | √ | √ | | -| TINShift | | √ | √ | | -| UpFirDn2d | | √ | | | -| Voxelization | √ | √ | | | -| PrRoIPool | | √ | | | +| Device | CPU | CUDA | MLU | MPS | Ascend | +| ---------------------------- | --- | ---- | --- | --- | ------ | +| ActiveRotatedFilter | √ | √ | | | | +| AssignScoreWithK | | √ | | | | +| BallQuery | | √ | | | | +| BBoxOverlaps | | √ | √ | √ | | +| BorderAlign | | √ | | | | +| BoxIouRotated | √ | √ | | | | +| BoxIouQuadri | √ | √ | | | | +| CARAFE | | √ | √ | | | +| ChamferDistance | | √ | | | | +| CrissCrossAttention | | √ | | | | +| ContourExpand | √ | | | | | +| ConvexIoU | | √ | | | | +| CornerPool | | √ | | | | 
+| Correlation | | √ | | | | +| Deformable Convolution v1/v2 | √ | √ | | | √ | +| Deformable RoIPool | | √ | √ | | √ | +| DiffIoURotated | | √ | | | | +| DynamicScatter | | √ | | | | +| FurthestPointSample | | √ | | | | +| FurthestPointSampleWithDist | | √ | | | | +| FusedBiasLeakyrelu | | √ | | | √ | +| GatherPoints | | √ | | | | +| GroupPoints | | √ | | | | +| Iou3d | | √ | √ | | | +| KNN | | √ | | | | +| MaskedConv | | √ | √ | | √ | +| MergeCells | | √ | | | | +| MinAreaPolygon | | √ | | | | +| ModulatedDeformConv2d | √ | √ | | | √ | +| MultiScaleDeformableAttn | | √ | √ | | | +| NMS | √ | √ | √ | | √ | +| NMSRotated | √ | √ | | | | +| NMSQuadri | √ | √ | | | | +| PixelGroup | √ | | | | | +| PointsInBoxes | √ | √ | | | | +| PointsInPolygons | | √ | | | | +| PSAMask | √ | √ | √ | | √ | +| RotatedFeatureAlign | √ | √ | | | | +| RoIPointPool3d | | √ | √ | | | +| RoIPool | | √ | √ | | √ | +| RoIAlignRotated | √ | √ | √ | | | +| RiRoIAlignRotated | | √ | | | | +| RoIAlign | √ | √ | √ | | | +| RoIAwarePool3d | | √ | √ | | | +| SAConv2d | | √ | | | | +| SigmoidFocalLoss | | √ | √ | | √ | +| SoftmaxFocalLoss | | √ | | | √ | +| SoftNMS | | √ | | | | +| Sparse Convolution | | √ | | | | +| Synchronized BatchNorm | | √ | | | | +| ThreeInterpolate | | √ | | | | +| ThreeNN | | √ | √ | | | +| TINShift | | √ | √ | | | +| UpFirDn2d | | √ | | | | +| Voxelization | √ | √ | | | | +| PrRoIPool | | √ | | | | diff --git a/mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp b/mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp new file mode 100644 index 0000000000..44ddb5431f --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp @@ -0,0 +1,75 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void psamask_forward_npu(const int psa_type, const Tensor x, Tensor y, + const int num, const int h_feature, + const int w_feature, const int h_mask, + const int w_mask, const int half_h_mask, + const int half_w_mask) { + int64_t psa_type_i64 = psa_type; + int64_t num_i64 = num; + int64_t h_feature_i64 = h_feature; + int64_t w_feature_i64 = w_feature; + int64_t h_mask_i64 = h_mask; + int64_t w_mask_i64 = w_mask; + int64_t half_h_mask_i64 = half_h_mask; + int64_t half_w_mask_i64 = half_w_mask; + OpCommand cmd; + cmd.Name("PSAMask") + .Input(x) + .Output(y) + .Attr("psa_type", psa_type_i64) + .Attr("num", num_i64) + .Attr("h_feature", h_feature_i64) + .Attr("w_feature", w_feature_i64) + .Attr("h_mask", h_mask_i64) + .Attr("w_mask", w_mask_i64) + .Attr("half_h_mask", half_h_mask_i64) + .Attr("half_w_mask", half_w_mask_i64) + .Run(); +} + +void psamask_forward_impl(const int psa_type, const Tensor x, Tensor y, + const int num, const int h_feature, + const int w_feature, const int h_mask, + const int w_mask, const int half_h_mask, + const int half_w_mask); + +void psamask_backward_npu(const int psa_type, const Tensor y_grad, + Tensor x_grad, const int num, const int h_feature, + const int w_feature, const int h_mask, + const int w_mask, const int half_h_mask, + const int half_w_mask) { + int64_t psa_type_i64 = psa_type; + int64_t num_i64 = num; + int64_t h_feature_i64 = h_feature; + int64_t w_feature_i64 = w_feature; + int64_t h_mask_i64 = h_mask; + int64_t w_mask_i64 = w_mask; + int64_t half_h_mask_i64 = half_h_mask; + int64_t half_w_mask_i64 = half_w_mask; + OpCommand cmd; + cmd.Name("PSAMaskGrad") + .Input(y_grad) + .Output(x_grad) + .Attr("psa_type", psa_type_i64) + .Attr("num", num_i64) + .Attr("h_feature", h_feature_i64) + .Attr("w_feature", w_feature_i64) + .Attr("h_mask", 
h_mask_i64) + .Attr("w_mask", w_mask_i64) + .Attr("half_h_mask", half_h_mask_i64) + .Attr("half_w_mask", half_w_mask_i64) + .Run(); +} + +void psamask_backward_impl(const int psa_type, const Tensor y_grad, + Tensor x_grad, const int num, const int h_feature, + const int w_feature, const int h_mask, + const int w_mask, const int half_h_mask, + const int half_w_mask); + +REGISTER_NPU_IMPL(psamask_forward_impl, psamask_forward_npu); +REGISTER_NPU_IMPL(psamask_backward_impl, psamask_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp new file mode 100644 index 0000000000..36bd9c7a80 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp @@ -0,0 +1,34 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, int pooled_width, + float spatial_scale) { + int64_t pooled_height_64 = pooled_height; + int64_t pooled_width_64 = pooled_width; + int64_t pooled_channel = 1; + at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor( + {}, rois.options().dtype(at::kInt), rois); + + OpCommand cmd; + cmd.Name("RoiPoolingWithArgMax") + .Input(input) + .Input(rois) + .Input(roi_actual_num) + .Output(output) + .Output(argmax) + .Attr("pooled_h", pooled_height_64) + .Attr("pooled_w", pooled_width_64) + .Attr("spatial_scale_h", spatial_scale) + .Attr("spatial_scale_w", spatial_scale) + .Attr("pool_channel", pooled_channel) + .Run(); +} + +void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, int pooled_width, + float spatial_scale); + +REGISTER_NPU_IMPL(roi_pool_forward_impl, roi_pool_forward_npu); diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index f30eb99b32..2c0b0898dc 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -13,6 +13,7 @@ from torch.nn.modules.utils import _pair, _single from ..utils import ext_loader +from .modulated_deform_conv import ModulatedDeformConv2dFunction ext_module = ext_loader.load_ext('_ext', [ 'deform_conv_forward', 'deform_conv_backward_input', @@ -47,6 +48,23 @@ def symbolic(g, bias_i=bias, im2col_step_i=im2col_step) + @staticmethod + def _npu_backward(ctx, grad_output): + input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ + ctx.saved_tensors + grad_input, grad_weight, grad_offset_all, grad_bias = \ + torch.npu_deformable_conv2dbk( + input_tensor, grad_output, offset_out, weight, offset_all, + kernel_size=[weight.shape[3], weight.shape[2]], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, deformable_groups=ctx.deform_groups, + modulated=True) + grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp) + return grad_input, grad_offset, grad_weight, \ + None, None, None, None, None, None, None + @staticmethod def forward(ctx, input: Tensor, @@ -80,6 +98,13 @@ def forward(ctx, # whatever the pytorch version is. 
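+        # NPU note: the branch below routes plain DeformConv2d through the
+        # modulated (v2) NPU path by synthesizing an all-ones modulation
+        # mask and an empty bias, so no dedicated v1 kernel is needed here.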
input = input.type_as(offset) weight = weight.type_as(input) + if ctx.device == 'npu': + mask_shape, _ = torch.chunk(offset, 2, dim=1) + mask = torch.ones_like(mask_shape).to(input.device) + bias = input.new_empty(0) + output = ModulatedDeformConv2dFunction._npu_forward( + ctx, input, offset, mask, weight, bias) + return output ctx.save_for_backward(input, offset, weight) output = input.new_empty( @@ -116,6 +141,8 @@ def backward( ctx, grad_output: Tensor ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None, None, None, None, None, None, None]: + if ctx.device == 'npu': + return DeformConv2dFunction._npu_backward(ctx, grad_output) input, offset, weight = ctx.saved_tensors grad_input = grad_offset = grad_weight = None diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py index dcd65de318..69aab73052 100644 --- a/mmcv/ops/focal_loss.py +++ b/mmcv/ops/focal_loss.py @@ -25,8 +25,7 @@ def forward(ctx, weight: Optional[torch.Tensor] = None, reduction: str = 'mean') -> torch.Tensor: - assert isinstance( - target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor)) + assert target.dtype == torch.long assert input.dim() == 2 assert target.dim() == 1 assert input.size(0) == target.size(0) diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index dc3c89d58c..acadc533bd 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -40,7 +40,7 @@ def _calculate_sort_index(kernel_h, kernel_w, deformable_group): split_num = deformable_group * 2 * kernel_h * kernel_w sort_index = list(range(split_num)) sort_index_fp = (sort_index[1::2] + sort_index[::2]) - sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)} + sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)} sort_index_bp = [sort_index_bp_dict[i] for i in sort_index] sort_index_fp = torch.IntTensor(sort_index_fp) sort_index_bp = torch.IntTensor(sort_index_bp) diff --git a/tests/test_ops/test_psa_mask.py b/tests/test_ops/test_psa_mask.py index 8c1f3101ab..b0fd86e8f5 100644 --- a/tests/test_ops/test_psa_mask.py +++ b/tests/test_ops/test_psa_mask.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE class Loss(nn.Module): @@ -28,7 +28,11 @@ class TestPSAMask: pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_psa_mask_collect(self, device): from mmcv.ops import PSAMask @@ -76,7 +80,11 @@ def test_psa_mask_collect(self, device): pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_psa_mask_distribute(self, device): from mmcv.ops import PSAMask diff --git a/tests/test_ops/test_roi_pool.py b/tests/test_ops/test_roi_pool.py index 39d0ddea96..c935c81145 100644 --- a/tests/test_ops/test_roi_pool.py +++ b/tests/test_ops/test_roi_pool.py @@ -5,7 +5,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE _USING_PARROTS = True try: @@ -86,7 +86,11 @@ 
def _test_roipool_allclose(self, device, dtype=torch.float): pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) @pytest.mark.parametrize('dtype', [ torch.float, From 8425903352a0583ca20599409296d5b2c42ce648 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Thu, 12 Jan 2023 16:19:03 +0800 Subject: [PATCH 4/5] merge conflicts in ops.md --- docs/en/understand_mmcv/ops.md | 1 + docs/zh_cn/understand_mmcv/ops.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index e4f3285984..5579cd7757 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -60,3 +60,4 @@ We implement common ops used in detection, segmentation, etc. | UpFirDn2d | | √ | | | | | Voxelization | √ | √ | | | | | PrRoIPool | | √ | | | | +| BezierAlign | √ | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 3d6eecca3b..fbc0f13386 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -60,3 +60,4 @@ MMCV 提供了检测、分割等任务中常用的算子 | UpFirDn2d | | √ | | | | | Voxelization | √ | √ | | | | | PrRoIPool | | √ | | | | +| BezierAlign | √ | √ | | | | From ca99d1750086d5a5314b938e9f50e44a5bd4c3d5 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Thu, 12 Jan 2023 17:04:19 +0800 Subject: [PATCH 5/5] [fix]: fix merge bug --- mmcv/ops/deform_conv.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 2c0b0898dc..bc71b5c078 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -88,6 +88,7 @@ def forward(ctx, ctx.groups = groups ctx.deform_groups = deform_groups ctx.im2col_step = im2col_step + ctx.device = input.device.type # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; # amp won't cast the type of model (float32), but "offset" is cast