From 093f46210101f437c9e25b613d64770f1779cd94 Mon Sep 17 00:00:00 2001 From: wangjiangben Date: Wed, 14 Sep 2022 12:02:27 +0800 Subject: [PATCH 01/67] init npu --- mmcv/device/__init__.py | 6 +- mmcv/device/npu/__init__.py | 6 ++ mmcv/device/npu/_functions.py | 25 ++++++++ mmcv/device/npu/data_parallel.py | 59 ++++++++++++++++++ mmcv/device/npu/distributed.py | 26 ++++++++ mmcv/device/npu/scatter_gather.py | 60 +++++++++++++++++++ mmcv/device/utils.py | 12 +++- mmcv/runner/dist_utils.py | 10 +++- mmcv/runner/hooks/optimizer.py | 8 ++- mmcv/utils/__init__.py | 4 +- mmcv/utils/device_type.py | 13 ++++ tests/test_device/test_device_utils.py | 7 ++- tests/test_device/test_functions.py | 13 +++- .../test_device/test_npu/test_npu_parallel.py | 37 ++++++++++++ 14 files changed, 274 insertions(+), 12 deletions(-) create mode 100644 mmcv/device/npu/__init__.py create mode 100644 mmcv/device/npu/_functions.py create mode 100644 mmcv/device/npu/data_parallel.py create mode 100644 mmcv/device/npu/distributed.py create mode 100644 mmcv/device/npu/scatter_gather.py create mode 100644 tests/test_device/test_npu/test_npu_parallel.py diff --git a/mmcv/device/__init__.py b/mmcv/device/__init__.py index ba217b0771..996f0ed391 100644 --- a/mmcv/device/__init__.py +++ b/mmcv/device/__init__.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import ipu, mlu, mps +from . import ipu, mlu, mps, npu from .scatter_gather import scatter, scatter_kwargs from .utils import get_device -__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs'] +__all__ = [ + 'npu', 'mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs' +] diff --git a/mmcv/device/npu/__init__.py b/mmcv/device/npu/__init__.py new file mode 100644 index 0000000000..1a93b39678 --- /dev/null +++ b/mmcv/device/npu/__init__.py @@ -0,0 +1,6 @@ +# Copyright Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. +from .data_parallel import NPUDataParallel +from .distributed import NPUDistributedDataParallel + +__all__ = ['NPUDataParallel', 'NPUDistributedDataParallel'] diff --git a/mmcv/device/npu/_functions.py b/mmcv/device/npu/_functions.py new file mode 100644 index 0000000000..e022e59cd2 --- /dev/null +++ b/mmcv/device/npu/_functions.py @@ -0,0 +1,25 @@ +# Copyright Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Union + +import torch + + +def scatter(input: Union[List, torch.Tensor], devices: List) -> List: + """scatter copies tensor to NPU directly.""" + if isinstance(input, list): + outputs = [scatter(_input, devices) for _input in input] + return outputs + elif isinstance(input, torch.Tensor): + output = input.contiguous() + return output.to('npu') if devices != [-1] else output + else: + raise Exception(f'Unknown type {type(input)}.') + + +class Scatter: + + @staticmethod + def forward(target_npus, input): + outputs = scatter(input, target_npus) + return tuple(outputs) if isinstance(outputs, list) else (outputs, ) diff --git a/mmcv/device/npu/data_parallel.py b/mmcv/device/npu/data_parallel.py new file mode 100644 index 0000000000..31352398b5 --- /dev/null +++ b/mmcv/device/npu/data_parallel.py @@ -0,0 +1,59 @@ +# Copyright Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. 
+ +import sys + +import torch + +from mmcv.parallel import MMDataParallel +from .scatter_gather import scatter_kwargs + + +def _check_balance(*args, **kwargs): + return + + +# Since we do not have a similar hardware unit multi_processor +# on the NPU, the corresponding# devices_properties does not +# have this property and cannot be checked. So we masked the +# _check_balance function in DataParallel to make initialization pass. +for m in sys.modules: + if m.startswith('torch') or 'mmcv' in m: + if getattr(sys.modules[m], '_check_balance', None) is not None: + setattr(sys.modules[m], '_check_balance', _check_balance) + + +class NPUDataParallel(MMDataParallel): + """The NPUDataParallel module that supports DataContainer. + + NPUDataParallel is a class inherited from MMDataParall, which supports + NPU training and inference only. + + The main differences with MMDataParallel: + + - It only supports single-card of NPU, and only use first card to + run training and inference. + + - It uses direct host-to-device copy instead of stream-background + scatter. + + .. warning:: + NPUDataParallel only supports single NPU training, if you need to + train with multiple NPUs, please use NPUDistributedDataParallel + instead. If you have multiple NPUs, you can toggle device_ids + parameters passed in for this function to specify the running device. + + Args: + module (:class:`nn.Module`): Module to be encapsulated. + dim (int): Dimension used to scatter the data. Defaults to 0. + """ + + def __init__(self, *args, dim=0, **kwargs): + super().__init__(*args, dim=dim, **kwargs) + device_id = kwargs.get('device_ids', [0])[0] + self.device_ids = [device_id] + self.src_device_obj = torch.device(f'npu:{device_id}') + torch.npu.set_device(self.src_device_obj) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) diff --git a/mmcv/device/npu/distributed.py b/mmcv/device/npu/distributed.py new file mode 100644 index 0000000000..f9e9ce46e8 --- /dev/null +++ b/mmcv/device/npu/distributed.py @@ -0,0 +1,26 @@ +# Copyright Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. + +from mmcv.parallel import MMDistributedDataParallel +from .scatter_gather import scatter_kwargs + + +class NPUDistributedDataParallel(MMDistributedDataParallel): + """The DDP module supports DataContainer. + + NPUDDP has one difference from MMDDP which moves data to NPU with coping + instead of scattering. + """ + + def to_kwargs(self, inputs, kwargs, device_id): + # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8 + # to move all tensors to device_id + return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim) + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def forward(self, *inputs, **kwargs): + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + return super().forward(*inputs[0], **kwargs[0]) diff --git a/mmcv/device/npu/scatter_gather.py b/mmcv/device/npu/scatter_gather.py new file mode 100644 index 0000000000..9b277d22f7 --- /dev/null +++ b/mmcv/device/npu/scatter_gather.py @@ -0,0 +1,60 @@ +# Copyright Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmcv.parallel.data_container import DataContainer +from ._functions import Scatter + + +def scatter(inputs, target_npus, dim=0): + """Scatter inputs to target npu. 
+ + The only difference from original :func:`scatter` is to add support for + :type:`~mmcv.parallel.DataContainer`. + """ + + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + if target_npus != [-1]: + obj = obj.to('npu') + return [obj] + else: + # for CPU inference we use self-implemented scatter + return Scatter.forward(target_npus, obj) + if isinstance(obj, DataContainer): + if obj.cpu_only: + return obj.data + else: + return Scatter.forward(target_npus, obj.data) + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + out = list(map(list, zip(*map(scatter_map, obj)))) + return out + if isinstance(obj, dict) and len(obj) > 0: + out = list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return out + return [obj for targets in target_npus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + return scatter_map(inputs) + finally: + scatter_map = None + + +def scatter_kwargs(inputs, kwargs, target_npus, dim=0): + """Scatter with support for kwargs dictionary.""" + inputs = scatter(inputs, target_npus, dim) if inputs else [] + kwargs = scatter(kwargs, target_npus, dim) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs diff --git a/mmcv/device/utils.py b/mmcv/device/utils.py index e2adec08dd..acdb473bcd 100644 --- a/mmcv/device/utils.py +++ b/mmcv/device/utils.py @@ -1,14 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE, + IS_NPU_AVAILABLE) def get_device() -> str: """Returns the currently existing device type. + .. note:: + Since npu provides tools to automatically convert cuda functions, + we need to make judgments on npu first to avoid entering + the cuda branch when using npu. + Returns: str: cuda | mlu | mps | cpu. 
""" - if IS_CUDA_AVAILABLE: + if IS_NPU_AVAILABLE: + return 'npu' + elif IS_CUDA_AVAILABLE: return 'cuda' elif IS_MLU_AVAILABLE: return 'mlu' diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py index ee55dfda36..45f73c9b0f 100644 --- a/mmcv/runner/dist_utils.py +++ b/mmcv/runner/dist_utils.py @@ -13,7 +13,7 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors, _unflatten_dense_tensors) -from mmcv.utils import IS_MLU_AVAILABLE +from mmcv.utils import IS_MLU_AVAILABLE, IS_NPU_AVAILABLE def _find_free_port() -> str: @@ -58,6 +58,14 @@ def _init_dist_pytorch(backend: str, **kwargs) -> None: rank=rank, world_size=int(os.environ['WORLD_SIZE']), **kwargs) + elif IS_NPU_AVAILABLE: + import torch_npu # noqa: F401 + torch.npu.set_device(rank) + dist.init_process_group( + backend='hccl', + rank=rank, + world_size=int(os.environ['WORLD_SIZE']), + **kwargs) else: num_gpus = torch.cuda.device_count() torch.cuda.set_device(rank % num_gpus) diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py index 5a2caec6f6..9301547501 100644 --- a/mmcv/runner/hooks/optimizer.py +++ b/mmcv/runner/hooks/optimizer.py @@ -9,7 +9,8 @@ from torch import Tensor from torch.nn.utils import clip_grad -from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version +from mmcv.utils import (IS_NPU_AVAILABLE, TORCH_VERSION, _BatchNorm, + digit_version) from ..dist_utils import allreduce_grads from ..fp16_utils import LossScaler, wrap_fp16_model from .hook import HOOKS, Hook @@ -17,7 +18,10 @@ try: # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported # and used; otherwise, auto fp16 will adopt mmcv's implementation. - from torch.cuda.amp import GradScaler + if IS_NPU_AVAILABLE: + from torch.npu.amp import GradScaler + else: + from torch.cuda.amp import GradScaler except ImportError: pass diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 8bb5a8173d..6dbdc2e1a2 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -37,7 +37,7 @@ ] else: from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE, - IS_MPS_AVAILABLE) + IS_MPS_AVAILABLE, IS_NPU_AVAILABLE) from .env import collect_env from .hub import load_url from .logging import get_logger, print_log @@ -77,5 +77,5 @@ 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE', 'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE', - 'IS_MPS_AVAILABLE', 'torch_meshgrid' + 'IS_MPS_AVAILABLE', 'IS_NPU_AVAILABLE', 'torch_meshgrid' ] diff --git a/mmcv/utils/device_type.py b/mmcv/utils/device_type.py index d42ff72e9f..cef966ac86 100644 --- a/mmcv/utils/device_type.py +++ b/mmcv/utils/device_type.py @@ -38,3 +38,16 @@ def is_mps_available() -> bool: IS_MPS_AVAILABLE = is_mps_available() + + +def is_npu_available() -> bool: + """Return True if npu devices exist.""" + try: + import torch + import torch_npu + return (hasattr(torch, 'npu') and torch_npu.npu.is_available()) + except Exception: + return False + + +IS_NPU_AVAILABLE = is_npu_available() diff --git a/tests/test_device/test_device_utils.py b/tests/test_device/test_device_utils.py index 6597efa5a3..11a34c97c1 100644 --- a/tests/test_device/test_device_utils.py +++ b/tests/test_device/test_device_utils.py @@ -1,11 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from mmcv.device import get_device -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE, + IS_NPU_AVAILABLE) def test_get_device(): current_device = get_device() - if IS_CUDA_AVAILABLE: + if IS_NPU_AVAILABLE: + assert current_device == 'npu' + elif IS_CUDA_AVAILABLE: assert current_device == 'cuda' elif IS_MLU_AVAILABLE: assert current_device == 'mlu' diff --git a/tests/test_device/test_functions.py b/tests/test_device/test_functions.py index dbbb8978b5..d0fb6a7ca4 100644 --- a/tests/test_device/test_functions.py +++ b/tests/test_device/test_functions.py @@ -3,7 +3,7 @@ import torch from mmcv.device._functions import Scatter, scatter -from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE +from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE, IS_NPU_AVAILABLE def test_scatter(): @@ -28,6 +28,17 @@ def test_scatter(): for input, output in zip(inputs, outputs): assert torch.allclose(input.to('mlu'), output) + # if the device is NPU, copy the input from CPU to NPU + if IS_NPU_AVAILABLE: + input = torch.zeros([1, 3, 3, 3]) + output = scatter(input=input, devices=[0]) + assert torch.allclose(input.to('npu'), output) + + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = scatter(input=inputs, devices=[0]) + for input, output in zip(inputs, outputs): + assert torch.allclose(input.to('npu'), output) + # if the device is MPS, copy the input from CPU to MPS if IS_MPS_AVAILABLE: input = torch.zeros([1, 3, 3, 3]) diff --git a/tests/test_device/test_npu/test_npu_parallel.py b/tests/test_device/test_npu/test_npu_parallel.py new file mode 100644 index 0000000000..ae5efa6b70 --- /dev/null +++ b/tests/test_device/test_npu/test_npu_parallel.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest.mock import MagicMock, patch + +import torch.nn as nn + +from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel +from mmcv.parallel import is_module_wrapper +from mmcv.utils import IS_NPU_AVAILABLE + + +def mock(*args, **kwargs): + pass + + +@patch('torch.distributed._broadcast_coalesced', mock) +@patch('torch.distributed.broadcast', mock) +@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock) +def test_is_module_wrapper(): + + class Model(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(2, 2, 1) + + def forward(self, x): + return self.conv(x) + + model = Model() + assert not is_module_wrapper(model) + + if IS_NPU_AVAILABLE: + npudp = NPUDataParallel(model) + assert is_module_wrapper(npudp) + + npuddp = NPUDistributedDataParallel(model, process_group=MagicMock()) + assert is_module_wrapper(npuddp) From 716b3b3693e97217bf908c34e7ca1544963d9699 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Mon, 19 Sep 2022 15:19:00 +0800 Subject: [PATCH 02/67] add npu extension and focal loss adapter --- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 42 ++++++ mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 128 +++++++++++++++++++ mmcv/ops/focal_loss.py | 3 +- setup.py | 29 ++++- 4 files changed, 200 insertions(+), 2 deletions(-) create mode 100644 mmcv/ops/csrc/common/pytorch_npu_helper.hpp create mode 100644 mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp new file mode 100644 index 0000000000..99dbd5f252 --- /dev/null +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -0,0 +1,42 @@ +/****************************************************************************** +* Copyright (c) 2022 Huawei Technologies Co., Ltd +* All rights reserved. +* +* Licensed under the BSD 3-Clause License (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://opensource.org/licenses/BSD-3-Clause +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+******************************************************************************/ + +#ifndef PYTORCH_NPU_HELPER_HPP_ +#define PYTORCH_NPU_HELPER_HPP_ + +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +#ifdef MMCV_WITH_NPU +#include +#include +#include +#define NPU_NAME_SPACE at_npu::native +#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value) +#define CHECK_NPU(x) \ + TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") +#else +// for torch 1.5.0 adapter only +#include +#include +#define NPU_NAME_SPACE at::native::npu +#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, NPU, value); +#define CHECK_NPU(x) \ + TORCH_CHECK(x.device().type() == at::kNPU, #x " must be a NPU tensor") +#endif + +#endif // PYTORCH_NPU_HELPER_HPP_ \ No newline at end of file diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp new file mode 100644 index 0000000000..12e4a31318 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -0,0 +1,128 @@ +#include +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + + +void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha) { + + at::Tensor target_y = at::reshape(target, input.sizes()); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); + } + + OpCommand cmd; + cmd.Name("SigmoidFocalLoss") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); +} + +void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha); + +void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, float alpha) { + + at::Tensor target_y = at::reshape(target, input.sizes()); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); + } + + OpCommand cmd; + cmd.Name("SigmoidFocalLossGrad") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); +} + +void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, float alpha); + +void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha) { + + int64_t n_class = input.size(1); + at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); + } + + 
OpCommand cmd; + cmd.Name("SoftmaxFocalLoss") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); +} + +void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, float alpha); + +void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff, + Tensor grad_input, float gamma, float alpha) { + + int64_t n_class = input.size(1); + at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); + } + + OpCommand cmd; + cmd.Name("SoftmaxFocalLossGrad") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); +} + +void softmax_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, Tensor buff, + Tensor grad_input, float gamma, float alpha); + +REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl, sigmoid_focal_loss_forward_npu); + +REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl, sigmoid_focal_loss_backward_npu); + +REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl, softmax_focal_loss_forward_npu); + +REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl, softmax_focal_loss_backward_npu); diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py index 3b203fc15b..d4cf07138c 100644 --- a/mmcv/ops/focal_loss.py +++ b/mmcv/ops/focal_loss.py @@ -143,7 +143,8 @@ def forward(ctx, weight: Optional[torch.Tensor] = None, reduction='mean') -> torch.Tensor: - assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor)) + assert isinstance( + target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor)) assert input.dim() == 2 assert target.dim() == 1 assert input.size(0) == target.size(0) diff --git a/setup.py b/setup.py index 274c13de33..f0fb05af95 100644 --- a/setup.py +++ b/setup.py @@ -280,7 +280,7 @@ def get_extensions(): if is_rocm_pytorch or torch.cuda.is_available() or os.getenv( 'FORCE_CUDA', '0') == '1': if is_rocm_pytorch: - define_macros += [('HIP_DIFF', None)] + define_macros += [('MMCV_WITH_HIP', None)] define_macros += [('MMCV_WITH_CUDA', None)] cuda_args = os.getenv('MMCV_CUDA_ARGS') extra_compile_args['nvcc'] = [cuda_args] if cuda_args else [] @@ -289,6 +289,7 @@ def get_extensions(): glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \ glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cpp') extension = CUDAExtension + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/pytorch')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda')) elif (hasattr(torch, 'is_mlu_available') and @@ -329,6 +330,32 @@ def get_extensions(): extension = CppExtension include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps')) + elif (os.getenv('FORCE_NPU', '0') == '1'): + print(f'Compiling {ext_name} only with CPU and NPU') + try: + has_npu = torch.npu.is_available() + print('torch_npu version 1.5 is available. 
', has_npu) + extension = CppExtension + except: + try: + import torch_npu + from torch_npu.utils.cpp_extension import NpuExtension + has_npu = torch_npu.npu.is_available() + print('torch_npu version 1.8 is available.: ', has_npu) + define_macros += [('MMCV_WITH_NPU', None)] + extension = NpuExtension + except: + print('can not find any torch_npu') + return extensions + + # src + op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \ + glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp') + + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu')) else: print(f'Compiling {ext_name} only with CPU') op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ From 6e53b3fa1d4a7bf3667a2b21532e88fb6730dcad Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 09:58:22 +0800 Subject: [PATCH 03/67] clean code --- setup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index f0fb05af95..de32aa571b 100644 --- a/setup.py +++ b/setup.py @@ -336,24 +336,24 @@ def get_extensions(): has_npu = torch.npu.is_available() print('torch_npu version 1.5 is available. ', has_npu) extension = CppExtension - except: + except Exception: try: import torch_npu - from torch_npu.utils.cpp_extension import NpuExtension + from torch_npu.utils.cpp_extension import NpuExtension has_npu = torch_npu.npu.is_available() print('torch_npu version 1.8 is available.: ', has_npu) define_macros += [('MMCV_WITH_NPU', None)] extension = NpuExtension - except: + except Exception: print('can not find any torch_npu') return extensions - + # src op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp') - + include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu')) else: From 08f0a16a8ae227bd23866d7b6f87bc415d090758 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 10:17:13 +0800 Subject: [PATCH 04/67] clean code --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 12e4a31318..2ab07ba931 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -5,7 +5,7 @@ using namespace NPU_NAME_SPACE; using namespace std; -void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, +void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { at::Tensor target_y = at::reshape(target, input.sizes()); @@ -30,10 +30,10 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Run(); } -void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, +void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha); -void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, +void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha) { at::Tensor target_y = 
at::reshape(target, input.sizes()); @@ -58,10 +58,10 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Run(); } -void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, +void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha); -void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, +void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { int64_t n_class = input.size(1); @@ -73,7 +73,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, if(weight_size > 0) { weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); } - + OpCommand cmd; cmd.Name("SoftmaxFocalLoss") .Input(input) @@ -87,7 +87,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Run(); } -void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, +void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha); void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff, @@ -102,7 +102,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, if(weight_size > 0) { weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); } - + OpCommand cmd; cmd.Name("SoftmaxFocalLossGrad") .Input(input) From da659cb1a85cd7c91bbcf3a87a4ef6fa42d3b8af Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 10:47:40 +0800 Subject: [PATCH 05/67] clean code --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 2ab07ba931..d3ece95975 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -4,7 +4,6 @@ using namespace NPU_NAME_SPACE; using namespace std; - void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { From 448476e8e37c9bacdfe6e222997aa7087cdf1167 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 14:08:54 +0800 Subject: [PATCH 06/67] clean code --- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp index 99dbd5f252..8c769e4047 100644 --- a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -30,6 +30,7 @@ #define CHECK_NPU(x) \ TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") #else + // for torch 1.5.0 adapter only #include #include @@ -39,4 +40,4 @@ TORCH_CHECK(x.device().type() == at::kNPU, #x " must be a NPU tensor") #endif -#endif // PYTORCH_NPU_HELPER_HPP_ \ No newline at end of file +#endif // PYTORCH_NPU_HELPER_HPP_ From afbd35126b472eedefb134aa34557e1ced966fce Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Tue, 20 Sep 2022 15:21:09 +0800 Subject: [PATCH 07/67] fix autocast bugs on npu (#2273) fix autocast bugs on npu (#2273) --- mmcv/runner/fp16_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mmcv/runner/fp16_utils.py 
b/mmcv/runner/fp16_utils.py index 4674d27a44..2c349b64fa 100644 --- a/mmcv/runner/fp16_utils.py +++ b/mmcv/runner/fp16_utils.py @@ -10,7 +10,7 @@ import torch.nn as nn from torch.nn.parameter import Parameter -from mmcv.utils import TORCH_VERSION, digit_version +from mmcv.utils import IS_NPU_AVAILABLE, TORCH_VERSION, digit_version from .dist_utils import allreduce_grads as _allreduce_grads try: @@ -18,7 +18,10 @@ # and used; otherwise, auto fp16 will adopt mmcv's implementation. # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16 # manually, so the behavior may not be consistent with real amp. - from torch.cuda.amp import autocast + if IS_NPU_AVAILABLE: + from torch.npu.amp import autocast + else: + from torch.cuda.amp import autocast except ImportError: pass From 26f35e0a50dce06b71f1c7524d09f84f7a1f6fe8 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 17:26:19 +0800 Subject: [PATCH 08/67] code format --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 232 ++++++++++--------- 1 file changed, 123 insertions(+), 109 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index d3ece95975..6d20988ed2 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -1,127 +1,141 @@ -#include #include "pytorch_npu_helper.hpp" using namespace NPU_NAME_SPACE; -using namespace std; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { - - at::Tensor target_y = at::reshape(target, input.sizes()); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); - } - - OpCommand cmd; - cmd.Name("SigmoidFocalLoss") - .Input(input) - .Input(target_y) - .Input(grad_up) - .Input(weight_y) - .Output(grad_input) - .Attr("gamma", gamma) - .Attr("alpha", alpha) - .Attr("reduction", "none") - .Run(); + at::Tensor target_y = at::reshape(target, input.sizes()); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast( + target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast( + weight, input.sizes()); + } + OpCommand cmd; + cmd.Name("SigmoidFocalLoss") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); } -void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, - Tensor output, float gamma, float alpha); - -void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, - Tensor grad_input, float gamma, float alpha) { - - at::Tensor target_y = at::reshape(target, input.sizes()); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); - } - - OpCommand cmd; - cmd.Name("SigmoidFocalLossGrad") - 
.Input(input) - .Input(target_y) - .Input(grad_up) - .Input(weight_y) - .Output(grad_input) - .Attr("gamma", gamma) - .Attr("alpha", alpha) - .Attr("reduction", "none") - .Run(); +void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, + Tensor weight, Tensor output, + float gamma, float alpha); + +void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, + Tensor weight, Tensor grad_input, + float gamma, float alpha) { + at::Tensor target_y = at::reshape(target, input.sizes()); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast( + target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast( + weight, input.sizes()); + } + OpCommand cmd; + cmd.Name("SigmoidFocalLossGrad") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); } -void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, - Tensor grad_input, float gamma, float alpha); - -void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, - Tensor output, float gamma, float alpha) { - - int64_t n_class = input.size(1); - at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); - } - - OpCommand cmd; - cmd.Name("SoftmaxFocalLoss") - .Input(input) - .Input(target_y) - .Input(grad_up) - .Input(weight_y) - .Output(grad_input) - .Attr("gamma", gamma) - .Attr("alpha", alpha) - .Attr("reduction", "none") - .Run(); +void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, + Tensor weight, Tensor grad_input, + float gamma, float alpha); + +void softmax_focal_loss_forward_npu(Tensor input, Tensor target, + Tensor weight, Tensor output, + float gamma, float alpha) { + int64_t n_class = input.size(1); + at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot( + target, n_class); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast( + target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast( + weight, input.sizes()); + } + OpCommand cmd; + cmd.Name("SoftmaxFocalLoss") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); } -void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, - Tensor grad_input, float gamma, float alpha); - -void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff, - Tensor grad_input, float gamma, float alpha) { - - int64_t n_class = input.size(1); - at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor 
weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); - } - - OpCommand cmd; - cmd.Name("SoftmaxFocalLossGrad") - .Input(input) - .Input(target_y) - .Input(grad_up) - .Input(weight_y) - .Output(grad_input) - .Attr("gamma", gamma) - .Attr("alpha", alpha) - .Attr("reduction", "none") - .Run(); +void softmax_focal_loss_forward_impl(Tensor input, Tensor target, + Tensor weight, Tensor grad_input, + float gamma, float alpha); + +void softmax_focal_loss_backward_npu(Tensor input, Tensor target, + Tensor weight, Tensor buff, + Tensor grad_input, float gamma, + float alpha) { + int64_t n_class = input.size(1); + at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot( + target, n_class); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast( + target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast( + weight, input.sizes()); + } + + OpCommand cmd; + cmd.Name("SoftmaxFocalLossGrad") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); } -void softmax_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, Tensor buff, - Tensor grad_input, float gamma, float alpha); +void softmax_focal_loss_backward_impl(Tensor input, Tensor target, + Tensor weight, Tensor buff, + Tensor grad_input, float gamma, + float alpha); -REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl, sigmoid_focal_loss_forward_npu); +REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl, + sigmoid_focal_loss_forward_npu); -REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl, sigmoid_focal_loss_backward_npu); +REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl, + sigmoid_focal_loss_backward_npu); -REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl, softmax_focal_loss_forward_npu); +REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl, + softmax_focal_loss_forward_npu); -REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl, softmax_focal_loss_backward_npu); +REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl, + softmax_focal_loss_backward_npu); From 58618ac9f7da2fbcaf2b1e848b5866b6fdc7861d Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 21:10:09 +0800 Subject: [PATCH 09/67] code format --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 77 ++++++++++---------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 6d20988ed2..982544bac0 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -5,14 +5,14 @@ using namespace NPU_NAME_SPACE; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { at::Tensor target_y = at::reshape(target, input.sizes()); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast( - target_y, at::kInt); + target_y = + at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast( - weight, input.sizes()); + if 
(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, + input.sizes()); } OpCommand cmd; cmd.Name("SigmoidFocalLoss") @@ -27,22 +27,21 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Run(); } -void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, - Tensor weight, Tensor output, - float gamma, float alpha); +void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha); -void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, - Tensor weight, Tensor grad_input, - float gamma, float alpha) { +void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, + float alpha) { at::Tensor target_y = at::reshape(target, input.sizes()); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast( - target_y, at::kInt); + target_y = + at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast( - weight, input.sizes()); + if (weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, + input.sizes()); } OpCommand cmd; cmd.Name("SigmoidFocalLossGrad") @@ -61,20 +60,19 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha); -void softmax_focal_loss_forward_npu(Tensor input, Tensor target, - Tensor weight, Tensor output, - float gamma, float alpha) { +void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha) { int64_t n_class = input.size(1); - at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot( - target, n_class); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast( - target_y, at::kInt); + at::Tensor target_y = + at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = + at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast( - weight, input.sizes()); + if (weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, + input.sizes()); } OpCommand cmd; cmd.Name("SoftmaxFocalLoss") @@ -89,25 +87,24 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, .Run(); } -void softmax_focal_loss_forward_impl(Tensor input, Tensor target, - Tensor weight, Tensor grad_input, - float gamma, float alpha); - -void softmax_focal_loss_backward_npu(Tensor input, Tensor target, - Tensor weight, Tensor buff, +void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, - float alpha) { + float alpha); + +void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, + Tensor buff, Tensor grad_input, + float gamma, float alpha) { int64_t n_class = input.size(1); - at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot( - target, n_class); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast( - target_y, at::kInt); + at::Tensor target_y = + at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = + 
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast( - weight, input.sizes()); + if (weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, + input.sizes()); } OpCommand cmd; From d75dcb10426703443531c76a4cb31bfa1d3d34cb Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Wed, 21 Sep 2022 15:46:06 +0800 Subject: [PATCH 10/67] code format --- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 14 +++----------- setup.py | 20 +++++--------------- 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp index 8c769e4047..54c80c7772 100644 --- a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -21,23 +21,15 @@ #include "pytorch_cpp_helper.hpp" #include "pytorch_device_registry.hpp" -#ifdef MMCV_WITH_NPU #include #include #include + #define NPU_NAME_SPACE at_npu::native + #define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value) -#define CHECK_NPU(x) \ - TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") -#else -// for torch 1.5.0 adapter only -#include -#include -#define NPU_NAME_SPACE at::native::npu -#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, NPU, value); #define CHECK_NPU(x) \ - TORCH_CHECK(x.device().type() == at::kNPU, #x " must be a NPU tensor") -#endif + TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") #endif // PYTORCH_NPU_HELPER_HPP_ diff --git a/setup.py b/setup.py index de32aa571b..46cd5a302f 100644 --- a/setup.py +++ b/setup.py @@ -333,27 +333,17 @@ def get_extensions(): elif (os.getenv('FORCE_NPU', '0') == '1'): print(f'Compiling {ext_name} only with CPU and NPU') try: - has_npu = torch.npu.is_available() - print('torch_npu version 1.5 is available. 
', has_npu) - extension = CppExtension + from torch_npu.utils.cpp_extension import NpuExtension + define_macros += [('MMCV_WITH_NPU', None)] + extension = NpuExtension except Exception: - try: - import torch_npu - from torch_npu.utils.cpp_extension import NpuExtension - has_npu = torch_npu.npu.is_available() - print('torch_npu version 1.8 is available.: ', has_npu) - define_macros += [('MMCV_WITH_NPU', None)] - extension = NpuExtension - except Exception: - print('can not find any torch_npu') - return extensions - + print('can not find any torch_npu') + return extensions # src op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp') - include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu')) else: From 7af2a6febe00051b8ad4618804749a9acb1536c5 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Wed, 21 Sep 2022 17:17:38 +0800 Subject: [PATCH 11/67] bug fix --- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 36 ++++++++++---------- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 8 ++--- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp index 54c80c7772..ef95716e9e 100644 --- a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -1,30 +1,30 @@ /****************************************************************************** -* Copyright (c) 2022 Huawei Technologies Co., Ltd -* All rights reserved. -* -* Licensed under the BSD 3-Clause License (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://opensource.org/licenses/BSD-3-Clause -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -******************************************************************************/ + * Copyright (c) 2022 Huawei Technologies Co., Ltd + * All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ #ifndef PYTORCH_NPU_HELPER_HPP_ #define PYTORCH_NPU_HELPER_HPP_ -#include "pytorch_cpp_helper.hpp" -#include "pytorch_device_registry.hpp" - #include #include #include +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + #define NPU_NAME_SPACE at_npu::native #define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 982544bac0..bd82824689 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -7,7 +7,6 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, at::Tensor target_y = at::reshape(target, input.sizes()); target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); at::Tensor weight_y = at::ones_like(input); if (weight_size > 0) { @@ -18,9 +17,8 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, cmd.Name("SigmoidFocalLoss") .Input(input) .Input(target_y) - .Input(grad_up) .Input(weight_y) - .Output(grad_input) + .Output(output) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", "none") @@ -67,7 +65,6 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, at_npu::native::NPUNativeFunctions::one_hot(target, n_class); target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); at::Tensor weight_y = at::ones_like(input); if (weight_size > 0) { @@ -78,9 +75,8 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, cmd.Name("SoftmaxFocalLoss") .Input(input) .Input(target_y) - .Input(grad_up) .Input(weight_y) - .Output(grad_input) + .Output(output) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", "none") From 268ff0e1766b89e4a82e7b7db8819c07858dc418 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Wed, 21 Sep 2022 18:20:18 +0800 Subject: [PATCH 12/67] pytorch_npu_helper.hpp clean code --- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp index ef95716e9e..88607d23b3 100644 --- a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -18,9 +18,9 @@ #ifndef PYTORCH_NPU_HELPER_HPP_ #define PYTORCH_NPU_HELPER_HPP_ +#include #include #include -#include #include "pytorch_cpp_helper.hpp" #include "pytorch_device_registry.hpp" From a0435417d9dbcdf9ec3b1c4ffb21646fc9b90339 Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Thu, 29 Sep 2022 16:48:36 +0800 Subject: [PATCH 13/67] Npu dev (#2306) * fix autocast bugs on npu * using scatter_kwargs in mmcv.device.scatter_gather --- mmcv/device/npu/_functions.py | 25 ------------- mmcv/device/npu/data_parallel.py | 4 +-- mmcv/device/npu/distributed.py | 2 +- mmcv/device/npu/scatter_gather.py | 60 ------------------------------- 4 files changed, 3 insertions(+), 88 deletions(-) delete mode 100644 mmcv/device/npu/_functions.py delete mode 100644 mmcv/device/npu/scatter_gather.py diff --git a/mmcv/device/npu/_functions.py b/mmcv/device/npu/_functions.py deleted 
file mode 100644 index e022e59cd2..0000000000 --- a/mmcv/device/npu/_functions.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright Huawei Technologies Co., Ltd. All rights reserved. -# Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Union - -import torch - - -def scatter(input: Union[List, torch.Tensor], devices: List) -> List: - """scatter copies tensor to NPU directly.""" - if isinstance(input, list): - outputs = [scatter(_input, devices) for _input in input] - return outputs - elif isinstance(input, torch.Tensor): - output = input.contiguous() - return output.to('npu') if devices != [-1] else output - else: - raise Exception(f'Unknown type {type(input)}.') - - -class Scatter: - - @staticmethod - def forward(target_npus, input): - outputs = scatter(input, target_npus) - return tuple(outputs) if isinstance(outputs, list) else (outputs, ) diff --git a/mmcv/device/npu/data_parallel.py b/mmcv/device/npu/data_parallel.py index 31352398b5..c107b2240e 100644 --- a/mmcv/device/npu/data_parallel.py +++ b/mmcv/device/npu/data_parallel.py @@ -5,8 +5,8 @@ import torch +from mmcv.device.scatter_gather import scatter_kwargs from mmcv.parallel import MMDataParallel -from .scatter_gather import scatter_kwargs def _check_balance(*args, **kwargs): @@ -19,7 +19,7 @@ def _check_balance(*args, **kwargs): # _check_balance function in DataParallel to make initialization pass. for m in sys.modules: if m.startswith('torch') or 'mmcv' in m: - if getattr(sys.modules[m], '_check_balance', None) is not None: + if hasattr(sys.modules[m], '_check_balance'): setattr(sys.modules[m], '_check_balance', _check_balance) diff --git a/mmcv/device/npu/distributed.py b/mmcv/device/npu/distributed.py index f9e9ce46e8..5f0fe55c8d 100644 --- a/mmcv/device/npu/distributed.py +++ b/mmcv/device/npu/distributed.py @@ -1,8 +1,8 @@ # Copyright Huawei Technologies Co., Ltd. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved. +from mmcv.device.scatter_gather import scatter_kwargs from mmcv.parallel import MMDistributedDataParallel -from .scatter_gather import scatter_kwargs class NPUDistributedDataParallel(MMDistributedDataParallel): diff --git a/mmcv/device/npu/scatter_gather.py b/mmcv/device/npu/scatter_gather.py deleted file mode 100644 index 9b277d22f7..0000000000 --- a/mmcv/device/npu/scatter_gather.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright Huawei Technologies Co., Ltd. All rights reserved. -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from mmcv.parallel.data_container import DataContainer -from ._functions import Scatter - - -def scatter(inputs, target_npus, dim=0): - """Scatter inputs to target npu. - - The only difference from original :func:`scatter` is to add support for - :type:`~mmcv.parallel.DataContainer`. 
- """ - - def scatter_map(obj): - if isinstance(obj, torch.Tensor): - if target_npus != [-1]: - obj = obj.to('npu') - return [obj] - else: - # for CPU inference we use self-implemented scatter - return Scatter.forward(target_npus, obj) - if isinstance(obj, DataContainer): - if obj.cpu_only: - return obj.data - else: - return Scatter.forward(target_npus, obj.data) - if isinstance(obj, tuple) and len(obj) > 0: - return list(zip(*map(scatter_map, obj))) - if isinstance(obj, list) and len(obj) > 0: - out = list(map(list, zip(*map(scatter_map, obj)))) - return out - if isinstance(obj, dict) and len(obj) > 0: - out = list(map(type(obj), zip(*map(scatter_map, obj.items())))) - return out - return [obj for targets in target_npus] - - # After scatter_map is called, a scatter_map cell will exist. This cell - # has a reference to the actual function scatter_map, which has references - # to a closure that has a reference to the scatter_map cell (because the - # fn is recursive). To avoid this reference cycle, we set the function to - # None, clearing the cell - try: - return scatter_map(inputs) - finally: - scatter_map = None - - -def scatter_kwargs(inputs, kwargs, target_npus, dim=0): - """Scatter with support for kwargs dictionary.""" - inputs = scatter(inputs, target_npus, dim) if inputs else [] - kwargs = scatter(kwargs, target_npus, dim) if kwargs else [] - if len(inputs) < len(kwargs): - inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) - elif len(kwargs) < len(inputs): - kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) - inputs = tuple(inputs) - kwargs = tuple(kwargs) - return inputs, kwargs From 90fc3dcd8d384aaaaa293c5f4b9a427fe63bc6be Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Thu, 29 Sep 2022 16:52:10 +0800 Subject: [PATCH 14/67] raise ImportError when compile with npu --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 46cd5a302f..98ba6514bc 100644 --- a/setup.py +++ b/setup.py @@ -337,8 +337,7 @@ def get_extensions(): define_macros += [('MMCV_WITH_NPU', None)] extension = NpuExtension except Exception: - print('can not find any torch_npu') - return extensions + raise ImportError('can not find any torch_npu') # src op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ From 6ffdb57d15222637df461201ef102b46bb7e0d4e Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Thu, 29 Sep 2022 20:11:16 +0800 Subject: [PATCH 15/67] add npu test case (#2307) * add npu test case --- tests/test_ops/test_focal_loss.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_ops/test_focal_loss.py b/tests/test_ops/test_focal_loss.py index 316f58469d..f7b012bef2 100644 --- a/tests/test_ops/test_focal_loss.py +++ b/tests/test_ops/test_focal_loss.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE _USING_PARROTS = True try: @@ -130,6 +130,10 @@ def test_softmax_half(self): self._test_softmax(dtype=torch.half) @pytest.mark.parametrize('device', [ + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')), pytest.param( 'cuda', marks=pytest.mark.skipif( @@ -143,6 +147,10 @@ def test_sigmoid_float(self, device): self._test_sigmoid(device=device, 
dtype=torch.float) @pytest.mark.parametrize('device', [ + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires NPU support')), pytest.param( 'cuda', marks=pytest.mark.skipif( From 3f118616d82a7e6b4dec74676c87b2a511270823 Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Thu, 29 Sep 2022 20:49:28 +0800 Subject: [PATCH 16/67] Update focal_loss.py --- mmcv/ops/focal_loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py index d4cf07138c..8132c0a1fb 100644 --- a/mmcv/ops/focal_loss.py +++ b/mmcv/ops/focal_loss.py @@ -38,8 +38,7 @@ def forward(ctx, weight: Optional[torch.Tensor] = None, reduction: str = 'mean') -> torch.Tensor: - assert isinstance( - target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor)) + assert target.dtype == torch.long assert input.dim() == 2 assert target.dim() == 1 assert input.size(0) == target.size(0) From c81daa3604e4a9b3c2a9a6d9636f57907d4b56f3 Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Fri, 30 Sep 2022 10:17:20 +0800 Subject: [PATCH 17/67] add comment --- mmcv/device/npu/distributed.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mmcv/device/npu/distributed.py b/mmcv/device/npu/distributed.py index 5f0fe55c8d..b5888a2509 100644 --- a/mmcv/device/npu/distributed.py +++ b/mmcv/device/npu/distributed.py @@ -21,6 +21,12 @@ def scatter(self, inputs, kwargs, device_ids): return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) def forward(self, *inputs, **kwargs): + # Due to the different writing methods of the model repo + # of openmmlab 1.x, the forward of DDP will be directly + # invoked in some scenarios, resulting in input not being + # moved to the device side in the npu scenario. + # We rewrote Forward to manually handle the input to the + # device side to avoid some device misalignment errors if self.device_ids: inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) return super().forward(*inputs[0], **kwargs[0]) From 1cee865587890d44fa986e83e1d41f8d781a2600 Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Fri, 30 Sep 2022 10:52:41 +0800 Subject: [PATCH 18/67] clean lint --- mmcv/device/npu/distributed.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/device/npu/distributed.py b/mmcv/device/npu/distributed.py index b5888a2509..105a5e6a38 100644 --- a/mmcv/device/npu/distributed.py +++ b/mmcv/device/npu/distributed.py @@ -21,11 +21,11 @@ def scatter(self, inputs, kwargs, device_ids): return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) def forward(self, *inputs, **kwargs): - # Due to the different writing methods of the model repo - # of openmmlab 1.x, the forward of DDP will be directly - # invoked in some scenarios, resulting in input not being + # Due to the different writing methods of the model repo + # of openmmlab 1.x, the forward of DDP will be directly + # invoked in some scenarios, resulting in input not being # moved to the device side in the npu scenario. 
- # We rewrote Forward to manually handle the input to the + # We rewrote Forward to manually handle the input to the # device side to avoid some device misalignment errors if self.device_ids: inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) From 5841f92cfe786f82292ac3cb1254e24d70186a08 Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Fri, 30 Sep 2022 10:53:29 +0800 Subject: [PATCH 19/67] update dtype assert --- mmcv/ops/focal_loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py index 8132c0a1fb..5a941c8653 100644 --- a/mmcv/ops/focal_loss.py +++ b/mmcv/ops/focal_loss.py @@ -142,8 +142,7 @@ def forward(ctx, weight: Optional[torch.Tensor] = None, reduction='mean') -> torch.Tensor: - assert isinstance( - target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor)) + assert target.dtype == torch.long assert input.dim() == 2 assert target.dim() == 1 assert input.size(0) == target.size(0) From 4ce938cfcd34e73083799213ec4367b55b2fca17 Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Fri, 30 Sep 2022 15:43:03 +0800 Subject: [PATCH 20/67] update DDP forward and comment --- mmcv/device/npu/distributed.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mmcv/device/npu/distributed.py b/mmcv/device/npu/distributed.py index 105a5e6a38..5e4468be5a 100644 --- a/mmcv/device/npu/distributed.py +++ b/mmcv/device/npu/distributed.py @@ -21,12 +21,13 @@ def scatter(self, inputs, kwargs, device_ids): return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) def forward(self, *inputs, **kwargs): - # Due to the different writing methods of the model repo - # of openmmlab 1.x, the forward of DDP will be directly - # invoked in some scenarios, resulting in input not being - # moved to the device side in the npu scenario. - # We rewrote Forward to manually handle the input to the - # device side to avoid some device misalignment errors + # Since the scatter method is not supported on the NPU + # and the DDP class is rewritten, when the forward of DDP + # is used, the NPU will mask the scatter branch, + # resulting in the input not being placed on the device side. + # So, forward has been rewritten here primarily to circumvent + # this situation that would cause the device misalignment. 
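# A minimal usage sketch (illustration only, not part of this patch): how the
# forward() rewritten in this hunk is expected to be exercised. The HCCL
# backend, master address/port, device id 0 and the toy nn.Conv2d model are
# assumptions for illustration; running it requires a machine with torch_npu.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu  # noqa: F401  # registers the `npu` device and torch.npu

from mmcv.device.npu import NPUDistributedDataParallel

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='hccl', world_size=1, rank=0)
torch.npu.set_device(0)

model = nn.Conv2d(3, 8, 3).npu()
ddp_model = NPUDistributedDataParallel(model, device_ids=[0])
# CPU inputs are copied to npu:0 by the scatter call inside forward().
out = ddp_model(torch.rand(4, 3, 32, 32))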
if self.device_ids: inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - return super().forward(*inputs[0], **kwargs[0]) + return super().forward(*inputs[0], **kwargs[0]) + return super().forward(*inputs, **kwargs) From ea1d8f8a243dfe631427aec6016bf46d19f38396 Mon Sep 17 00:00:00 2001 From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Date: Fri, 30 Sep 2022 17:34:08 +0800 Subject: [PATCH 21/67] fix bug Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- tests/test_ops/test_focal_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ops/test_focal_loss.py b/tests/test_ops/test_focal_loss.py index f7b012bef2..ee7c9861ae 100644 --- a/tests/test_ops/test_focal_loss.py +++ b/tests/test_ops/test_focal_loss.py @@ -150,7 +150,7 @@ def test_sigmoid_float(self, device): pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires NPU support')), + not IS_NPU_AVAILABLE, reason='requires NPU support')), pytest.param( 'cuda', marks=pytest.mark.skipif( From 8c0945cd74fe55a1ed55403707273c66fbab73e8 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Sat, 8 Oct 2022 20:21:54 +0800 Subject: [PATCH 22/67] sigmoidfocalloss npu adapter bug fix --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 35 ++++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index bd82824689..4e0acd3c96 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -1,10 +1,20 @@ #include "pytorch_npu_helper.hpp" using namespace NPU_NAME_SPACE; +using namespace std; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { - at::Tensor target_y = at::reshape(target, input.sizes()); + int64_t n_class = input.size(1); + at::Tensor target_y = = at::ones_like(input); + if(n_class == 1) { + target_y = at::reshape(target, input.sizes()); + target_y = at::mul(target_y, -1.0); + target_y = at::add(target_y, 1.0); + } + else { + target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + } target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); int64_t weight_size = weight.size(0); @@ -14,6 +24,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, input.sizes()); } OpCommand cmd; + string reduction = "none"; cmd.Name("SigmoidFocalLoss") .Input(input) .Input(target_y) @@ -21,7 +32,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Output(output) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } @@ -31,7 +42,16 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha) { - at::Tensor target_y = at::reshape(target, input.sizes()); + int64_t n_class = input.size(1); + at::Tensor target_y = = at::ones_like(input); + if(n_class == 1) { + target_y = at::reshape(target, input.sizes()); + } + else { + target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = at::mul(target_y, -1.0); + target_y = at::add(target_y, 1.0); + } target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); at::Tensor grad_up = 
at::ones_like(input); @@ -42,6 +62,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, input.sizes()); } OpCommand cmd; + string reduction = "none"; cmd.Name("SigmoidFocalLossGrad") .Input(input) .Input(target_y) @@ -50,7 +71,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Output(grad_input) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } @@ -72,6 +93,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, input.sizes()); } OpCommand cmd; + string reduction = "none"; cmd.Name("SoftmaxFocalLoss") .Input(input) .Input(target_y) @@ -79,7 +101,7 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Output(output) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } @@ -104,6 +126,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, } OpCommand cmd; + string reduction = "none"; cmd.Name("SoftmaxFocalLossGrad") .Input(input) .Input(target_y) @@ -112,7 +135,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Output(grad_input) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } From 183e4af27a3aa188c2273c0043581cbb9d722b5f Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Sun, 9 Oct 2022 18:50:54 +0800 Subject: [PATCH 23/67] BugFix: modify softmaxFocalLoss adapter --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 4e0acd3c96..0d00b12d42 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -92,17 +92,25 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); } + at::Tensor op_output = at::ones_like(input); OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLoss") .Input(input) .Input(target_y) .Input(weight_y) - .Output(output) + .Output(op_output) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + int64_t n_batch = input.size(0); + c10::SmallVector offsets = {0,0}; + c10::SmallVector sizes = {n_batch,1}; + at::IntArrayRef offset = at::IntArrayRef(offsets); + at::IntArrayRef size = at::IntArrayRef(sizes); + at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, + size, output); } void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, @@ -124,7 +132,6 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); } - OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLossGrad") From 9d1376e87ad2368f21550b246c35d067f2bbcc57 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Sun, 9 Oct 2022 19:00:44 +0800 Subject: [PATCH 24/67] BugFix: remove equal sign in the code --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 0d00b12d42..030fa02fb6 100644 --- 
a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -6,7 +6,7 @@ using namespace std; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { int64_t n_class = input.size(1); - at::Tensor target_y = = at::ones_like(input); + at::Tensor target_y = at::ones_like(input); if(n_class == 1) { target_y = at::reshape(target, input.sizes()); target_y = at::mul(target_y, -1.0); target_y = at::add(target_y, 1.0); @@ -43,7 +43,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha) { int64_t n_class = input.size(1); - at::Tensor target_y = = at::ones_like(input); + at::Tensor target_y = at::ones_like(input); if(n_class == 1) { target_y = at::reshape(target, input.sizes()); } From 33dbcde4c6164a13403d46520604d5e3a2a7453f Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Wed, 12 Oct 2022 14:29:57 +0800 Subject: [PATCH 25/67] add npu install information in README --- .pre-commit-config.yaml | 1 + README.md | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4dd84c0b4..2c300d103d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,6 +29,7 @@ repos: rev: v2.1.0 hooks: - id: codespell + exclude: ^README.md - repo: https://github.com/executablebooks/mdformat rev: 0.7.9 hooks: diff --git a/README.md b/README.md index 1a6541a689..57247eb8b0 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,39 @@ c. Install full version with custom operators for onnxruntime If you would like to build MMCV from source, please refer to the [guide](https://mmcv.readthedocs.io/en/latest/get_started/build.html). +## NPU Build and Installation + +If you want to run mmcv on your NPU device, you can build and install mmcv-npu by following the steps below. + +a. Install the **ascend-toolkit** + +```bash + Ascend-cann-toolkit_{version}_linux-{arch}.run +``` + +- You can download the ascend-toolkit package from https://www.hiascend.com/software/cann/community. Choose the **"Ascend-cann-toolkit\_{xxx.xxx}.run"** that fits your development environment. +- To install **CANN** quickly, you can refer to the documentation at https://www.hiascend.com/document/detail/zh/canncommercial/51RC2/envdeployment/instg/instg_000052.html + +b. Install **torch_npu** + +- As the dispatch mechanism is based on torch, you have to install torch_npu before running mmcv.ops on an NPU device. +- You can download the torch_npu code from https://gitee.com/ascend/pytorch and install torch_npu following the steps in its README. +- torch_npu depends on ascend-toolkit, so you have to install ascend-toolkit and set up the Ascend environment first: +- ```bash + source /usr/local/Ascend/ascend-toolkit/set_env.sh + ``` + +c.
build and install mmcv-npu + +- ```bash + MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py build_ext + MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py develop + ``` +- or +- ```bash + MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py install + ``` + ## FAQ If you face some installation issues, CUDA related issues or RuntimeErrors, From 09dd08140e50d6618333bf483a94e1a67502e823 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Thu, 3 Nov 2022 15:26:04 +0800 Subject: [PATCH 26/67] add modulatedDeformConv npu adapter --- mmcv/ops/modulated_deform_conv.py | 74 +++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index df5095f2e9..55893b814a 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -34,6 +34,72 @@ def symbolic(g, input, offset, mask, weight, bias, stride, padding, groups_i=groups, deform_groups_i=deform_groups) + @staticmethod + def _calculate_npu_sort_index(kernel_h, kernel_w, deformable_group): + split_num = deformable_group * 2 * kernel_h * kernel_w + sort_index_for_npu = list(range(split_num)) + sort_index_for_npu_fp = ( + sort_index_for_npu[1::2] + sort_index_for_npu[::2]) + sort_index_for_npu_bp_dict = { + i: idx + for idx, i in enumerate(sort_index_for_npu) + } + sort_index_for_npu_bp = [ + sort_index_for_npu_bp_dict[i] for i in sort_index_for_npu + ] + sort_index_for_npu_fp = torch.IntTensor(sort_index_for_npu_fp) + sort_index_for_npu_bp = torch.IntTensor(sort_index_for_npu_bp) + sort_index_for_npu_fp = sort_index_for_npu_fp.npu() + sort_index_for_npu_bp = sort_index_for_npu_bp.npu() + return sort_index_for_npu_fp, sort_index_for_npu_bp + + @staticmethod + def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): + _, _, k1, k2 = weight.shape + conv2d_bias = bias if len(bias) > 0 else None + sort_index_for_npu_fp, sort_index_for_npu_bp = \ + ModulatedDeformConv2dFunction._calculate_npu_sort_index( + k2, k1, ctx.deform_groups) + select_offset = offset.index_select(1, sort_index_for_npu_fp) + offset_all = torch.cat([select_offset, mask], dim=1) + output, offset_out = torch.npu_deformable_conv2d( + input_tensor, + weight, + offset_all, + conv2d_bias, + kernel_size=[k2, k1], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, + deformable_groups=ctx.deform_groups, + modulated=True) + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input_tensor.requires_grad: + ctx.save_for_backward(input_tensor, weight, offset_out, offset_all, + sort_index_for_npu_bp) + return output + + @staticmethod + def _npu_backward(ctx, grad_output): + input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ + ctx.saved_tensors + grad_input, grad_weight, grad_offset_all, grad_bias = \ + torch.npu_deformable_conv2dbk( + input_tensor, grad_output, offset_out, weight, offset_all, + kernel_size=[weight.shape[3], weight.shape[2]], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, deformable_groups=ctx.deform_groups, + modulated=True) + grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp) + grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :] + if not ctx.with_bias: + grad_bias = None + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, 
None, None, None, None, None, None) + @staticmethod def forward(ctx, input: torch.Tensor, @@ -56,6 +122,7 @@ def forward(ctx, ctx.groups = groups ctx.deform_groups = deform_groups ctx.with_bias = bias is not None + ctx.device = input.device.type if not ctx.with_bias: bias = input.new_empty(0) # fake tensor # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; @@ -69,6 +136,10 @@ def forward(ctx, weight = weight.type_as(input) bias = bias.type_as(input) # type: ignore mask = mask.type_as(input) + if ctx.device == 'npu': + output = ModulatedDeformConv2dFunction._npu_forward( + ctx, input, offset, mask, weight, bias) + return output ctx.save_for_backward(input, offset, mask, weight, bias) output = input.new_empty( ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) @@ -98,6 +169,9 @@ def forward(ctx, @staticmethod @once_differentiable def backward(ctx, grad_output: torch.Tensor) -> tuple: + if ctx.device == 'npu': + return ModulatedDeformConv2dFunction._npu_backward( + ctx, grad_output) input, offset, mask, weight, bias = ctx.saved_tensors grad_input = torch.zeros_like(input) grad_offset = torch.zeros_like(offset) From ff3ffbb191610dd61b11462df1c6b58d2d92a383 Mon Sep 17 00:00:00 2001 From: wangjiangben Date: Wed, 14 Sep 2022 12:02:27 +0800 Subject: [PATCH 27/67] init npu --- mmcv/device/npu/__init__.py | 2 +- mmcv/device/npu/data_parallel.py | 2 +- mmcv/device/npu/distributed.py | 2 +- mmcv/runner/dist_utils.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/device/npu/__init__.py b/mmcv/device/npu/__init__.py index 1a93b39678..43f982b40e 100644 --- a/mmcv/device/npu/__init__.py +++ b/mmcv/device/npu/__init__.py @@ -3,4 +3,4 @@ from .data_parallel import NPUDataParallel from .distributed import NPUDistributedDataParallel -__all__ = ['NPUDataParallel', 'NPUDistributedDataParallel'] +__all__ = ['NPUDataParallel', 'NPUDistributedDataParallel'] \ No newline at end of file diff --git a/mmcv/device/npu/data_parallel.py b/mmcv/device/npu/data_parallel.py index c107b2240e..0671f46c34 100644 --- a/mmcv/device/npu/data_parallel.py +++ b/mmcv/device/npu/data_parallel.py @@ -56,4 +56,4 @@ def __init__(self, *args, dim=0, **kwargs): torch.npu.set_device(self.src_device_obj) def scatter(self, inputs, kwargs, device_ids): - return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) \ No newline at end of file diff --git a/mmcv/device/npu/distributed.py b/mmcv/device/npu/distributed.py index 5e4468be5a..e57ba3eda6 100644 --- a/mmcv/device/npu/distributed.py +++ b/mmcv/device/npu/distributed.py @@ -30,4 +30,4 @@ def forward(self, *inputs, **kwargs): if self.device_ids: inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) return super().forward(*inputs[0], **kwargs[0]) - return super().forward(*inputs, **kwargs) + return super().forward(*inputs, **kwargs) \ No newline at end of file diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py index c061b3c111..a90ab67572 100644 --- a/mmcv/runner/dist_utils.py +++ b/mmcv/runner/dist_utils.py @@ -217,4 +217,4 @@ def _allreduce_coalesced(tensors: torch.Tensor, flat_tensors.div_(world_size) for tensor, synced in zip( bucket, _unflatten_dense_tensors(flat_tensors, bucket)): - tensor.copy_(synced) + tensor.copy_(synced) \ No newline at end of file From 57e417180a2407edb6db8363d28e88c5967dcc74 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Mon, 19 Sep 2022 15:19:00 +0800 Subject: [PATCH 28/67] add npu extension 
and focal loss adapter --- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 43 ++++++++++++++++++ mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 48 ++++++++++++++++---- mmcv/ops/focal_loss.py | 2 +- setup.py | 22 +++++++++ 4 files changed, 105 insertions(+), 10 deletions(-) diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp index 88607d23b3..9fcbe94f65 100644 --- a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -1,4 +1,5 @@ /****************************************************************************** +<<<<<<< HEAD * Copyright (c) 2022 Huawei Technologies Co., Ltd * All rights reserved. * @@ -14,10 +15,28 @@ * See the License for the specific language governing permissions and * limitations under the License. ******************************************************************************/ +======= +* Copyright (c) 2022 Huawei Technologies Co., Ltd +* All rights reserved. +* +* Licensed under the BSD 3-Clause License (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://opensource.org/licenses/BSD-3-Clause +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +******************************************************************************/ +>>>>>>> 716b3b3 (add npu extension and focal loss adapter) #ifndef PYTORCH_NPU_HELPER_HPP_ #define PYTORCH_NPU_HELPER_HPP_ +<<<<<<< HEAD #include #include #include @@ -33,3 +52,27 @@ TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") #endif // PYTORCH_NPU_HELPER_HPP_ +======= +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +#ifdef MMCV_WITH_NPU +#include +#include +#include +#define NPU_NAME_SPACE at_npu::native +#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value) +#define CHECK_NPU(x) \ + TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") +#else +// for torch 1.5.0 adapter only +#include +#include +#define NPU_NAME_SPACE at::native::npu +#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, NPU, value); +#define CHECK_NPU(x) \ + TORCH_CHECK(x.device().type() == at::kNPU, #x " must be a NPU tensor") +#endif + +#endif // PYTORCH_NPU_HELPER_HPP_ +>>>>>>> 716b3b3 (add npu extension and focal loss adapter) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index bd82824689..02d8002383 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -1,10 +1,20 @@ #include "pytorch_npu_helper.hpp" using namespace NPU_NAME_SPACE; +using namespace std; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { - at::Tensor target_y = at::reshape(target, input.sizes()); + int64_t n_class = input.size(1); + at::Tensor target_y = at::ones_like(input); + if(n_class == 1) { + target_y = at::reshape(target, input.sizes()); + target_y = at::mul(target_y, -1.0); + target_y = at::add(target_y, 1.0); + } + else { + target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + } target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, 
at::kInt); int64_t weight_size = weight.size(0); @@ -14,6 +24,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, input.sizes()); } OpCommand cmd; + string reduction = "none"; cmd.Name("SigmoidFocalLoss") .Input(input) .Input(target_y) @@ -21,7 +32,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Output(output) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } @@ -31,7 +42,16 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha) { - at::Tensor target_y = at::reshape(target, input.sizes()); + int64_t n_class = input.size(1); + at::Tensor target_y = at::ones_like(input); + if(n_class == 1) { + target_y = at::reshape(target, input.sizes()); + } + else { + target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = at::mul(target_y, -1.0); + target_y = at::add(target_y, 1.0); + } target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); at::Tensor grad_up = at::ones_like(input); @@ -42,6 +62,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, input.sizes()); } OpCommand cmd; + string reduction = "none"; cmd.Name("SigmoidFocalLossGrad") .Input(input) .Input(target_y) @@ -50,7 +71,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Output(grad_input) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } @@ -71,16 +92,25 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); } + at::Tensor op_output = at::ones_like(input); OpCommand cmd; + string reduction = "none"; cmd.Name("SoftmaxFocalLoss") .Input(input) .Input(target_y) .Input(weight_y) - .Output(output) + .Output(op_output) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); + int64_t n_batch = input.size(0); + c10::SmallVector offsets = {0,0}; + c10::SmallVector sizes = {n_batch,1}; + at::IntArrayRef offset = at::IntArrayRef(offsets); + at::IntArrayRef size = at::IntArrayRef(sizes); + at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, + size, output); } void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, @@ -102,8 +132,8 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); } - OpCommand cmd; + string reduction = "none"; cmd.Name("SoftmaxFocalLossGrad") .Input(input) .Input(target_y) @@ -112,7 +142,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Output(grad_input) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } @@ -131,4 +161,4 @@ REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl, softmax_focal_loss_forward_npu); REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl, - softmax_focal_loss_backward_npu); + softmax_focal_loss_backward_npu); \ No newline at end of file diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py index 5a941c8653..491123eccf 100644 --- a/mmcv/ops/focal_loss.py +++ b/mmcv/ops/focal_loss.py @@ -231,4 +231,4 @@ def 
__repr__(self): s += f'(gamma={self.gamma}, ' s += f'alpha={self.alpha}, ' s += f'reduction={self.reduction})' - return s + return s \ No newline at end of file diff --git a/setup.py b/setup.py index 98ba6514bc..59cbc0c637 100644 --- a/setup.py +++ b/setup.py @@ -333,16 +333,38 @@ def get_extensions(): elif (os.getenv('FORCE_NPU', '0') == '1'): print(f'Compiling {ext_name} only with CPU and NPU') try: +<<<<<<< HEAD from torch_npu.utils.cpp_extension import NpuExtension define_macros += [('MMCV_WITH_NPU', None)] extension = NpuExtension except Exception: raise ImportError('can not find any torch_npu') +======= + has_npu = torch.npu.is_available() + print('torch_npu version 1.5 is available. ', has_npu) + extension = CppExtension + except: + try: + import torch_npu + from torch_npu.utils.cpp_extension import NpuExtension + has_npu = torch_npu.npu.is_available() + print('torch_npu version 1.8 is available.: ', has_npu) + define_macros += [('MMCV_WITH_NPU', None)] + extension = NpuExtension + except: + print('can not find any torch_npu') + return extensions + +>>>>>>> 716b3b3 (add npu extension and focal loss adapter) # src op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp') +<<<<<<< HEAD +======= + +>>>>>>> 716b3b3 (add npu extension and focal loss adapter) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu')) else: From 65b5fe182bdaba231671cc69d6fa86133e925b2c Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 09:58:22 +0800 Subject: [PATCH 29/67] clean code --- mmcv/ops/csrc/common/pytorch_npu_helper.hpp | 45 +-------------------- setup.py | 14 +++++-- 2 files changed, 12 insertions(+), 47 deletions(-) diff --git a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp index 9fcbe94f65..0bcb20ba5f 100644 --- a/mmcv/ops/csrc/common/pytorch_npu_helper.hpp +++ b/mmcv/ops/csrc/common/pytorch_npu_helper.hpp @@ -1,5 +1,4 @@ /****************************************************************************** -<<<<<<< HEAD * Copyright (c) 2022 Huawei Technologies Co., Ltd * All rights reserved. * @@ -15,28 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. ******************************************************************************/ -======= -* Copyright (c) 2022 Huawei Technologies Co., Ltd -* All rights reserved. -* -* Licensed under the BSD 3-Clause License (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://opensource.org/licenses/BSD-3-Clause -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-******************************************************************************/ ->>>>>>> 716b3b3 (add npu extension and focal loss adapter) #ifndef PYTORCH_NPU_HELPER_HPP_ #define PYTORCH_NPU_HELPER_HPP_ -<<<<<<< HEAD #include #include #include @@ -51,28 +32,4 @@ #define CHECK_NPU(x) \ TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") -#endif // PYTORCH_NPU_HELPER_HPP_ -======= -#include "pytorch_cpp_helper.hpp" -#include "pytorch_device_registry.hpp" - -#ifdef MMCV_WITH_NPU -#include -#include -#include -#define NPU_NAME_SPACE at_npu::native -#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value) -#define CHECK_NPU(x) \ - TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") -#else -// for torch 1.5.0 adapter only -#include -#include -#define NPU_NAME_SPACE at::native::npu -#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, NPU, value); -#define CHECK_NPU(x) \ - TORCH_CHECK(x.device().type() == at::kNPU, #x " must be a NPU tensor") -#endif - -#endif // PYTORCH_NPU_HELPER_HPP_ ->>>>>>> 716b3b3 (add npu extension and focal loss adapter) +#endif // PYTORCH_NPU_HELPER_HPP_ \ No newline at end of file diff --git a/setup.py b/setup.py index 59cbc0c637..bf3c21fdcc 100644 --- a/setup.py +++ b/setup.py @@ -343,28 +343,36 @@ def get_extensions(): has_npu = torch.npu.is_available() print('torch_npu version 1.5 is available. ', has_npu) extension = CppExtension - except: + except Exception: try: import torch_npu - from torch_npu.utils.cpp_extension import NpuExtension + from torch_npu.utils.cpp_extension import NpuExtension has_npu = torch_npu.npu.is_available() print('torch_npu version 1.8 is available.: ', has_npu) define_macros += [('MMCV_WITH_NPU', None)] extension = NpuExtension - except: + except Exception: print('can not find any torch_npu') return extensions +<<<<<<< HEAD >>>>>>> 716b3b3 (add npu extension and focal loss adapter) +======= + +>>>>>>> 6e53b3f (clean code) # src op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp') <<<<<<< HEAD +<<<<<<< HEAD ======= >>>>>>> 716b3b3 (add npu extension and focal loss adapter) +======= + +>>>>>>> 6e53b3f (clean code) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu')) else: From f38b1b2ef38ebcef4378ed90d4b87579cf11b40b Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 10:17:13 +0800 Subject: [PATCH 30/67] clean code --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 87 ++++++++++++++++++++ setup.py | 32 +------ 2 files changed, 88 insertions(+), 31 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 02d8002383..4634fa2e66 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -3,6 +3,10 @@ using namespace NPU_NAME_SPACE; using namespace std; +<<<<<<< HEAD +======= + +>>>>>>> 08f0a16 (clean code) void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { int64_t n_class = input.size(1); @@ -40,6 +44,7 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha); void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, +<<<<<<< HEAD Tensor 
grad_input, float gamma, float alpha) { int64_t n_class = input.size(1); @@ -144,6 +149,88 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); +======= + Tensor grad_input, float gamma, float alpha) { + + at::Tensor target_y = at::reshape(target, input.sizes()); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); + } + + OpCommand cmd; + cmd.Name("SigmoidFocalLossGrad") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); +} + +void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, float alpha); + +void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, + Tensor output, float gamma, float alpha) { + + int64_t n_class = input.size(1); + at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); + } + + OpCommand cmd; + cmd.Name("SoftmaxFocalLoss") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); +} + +void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, + Tensor grad_input, float gamma, float alpha); + +void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff, + Tensor grad_input, float gamma, float alpha) { + + int64_t n_class = input.size(1); + at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); + at::Tensor grad_up = at::ones_like(input); + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input); + if(weight_size > 0) { + weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); + } + + OpCommand cmd; + cmd.Name("SoftmaxFocalLossGrad") + .Input(input) + .Input(target_y) + .Input(grad_up) + .Input(weight_y) + .Output(grad_input) + .Attr("gamma", gamma) + .Attr("alpha", alpha) + .Attr("reduction", "none") + .Run(); +>>>>>>> 08f0a16 (clean code) } void softmax_focal_loss_backward_impl(Tensor input, Tensor target, diff --git a/setup.py b/setup.py index bf3c21fdcc..67efcfed65 100644 --- a/setup.py +++ b/setup.py @@ -333,46 +333,16 @@ def get_extensions(): elif (os.getenv('FORCE_NPU', '0') == '1'): print(f'Compiling {ext_name} only with CPU and NPU') try: -<<<<<<< HEAD from torch_npu.utils.cpp_extension import NpuExtension define_macros += [('MMCV_WITH_NPU', None)] extension = NpuExtension except Exception: raise ImportError('can not find any torch_npu') -======= - has_npu = torch.npu.is_available() - print('torch_npu version 1.5 is available. 
', has_npu) - extension = CppExtension - except Exception: - try: - import torch_npu - from torch_npu.utils.cpp_extension import NpuExtension - has_npu = torch_npu.npu.is_available() - print('torch_npu version 1.8 is available.: ', has_npu) - define_macros += [('MMCV_WITH_NPU', None)] - extension = NpuExtension - except Exception: - print('can not find any torch_npu') - return extensions -<<<<<<< HEAD - ->>>>>>> 716b3b3 (add npu extension and focal loss adapter) -======= - ->>>>>>> 6e53b3f (clean code) # src op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp') -<<<<<<< HEAD -<<<<<<< HEAD -======= - ->>>>>>> 716b3b3 (add npu extension and focal loss adapter) -======= - ->>>>>>> 6e53b3f (clean code) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu')) else: @@ -490,4 +460,4 @@ def get_extensions(): }, ext_modules=get_extensions(), cmdclass=cmd_class, - zip_safe=False) + zip_safe=False) \ No newline at end of file From f5c156f3e4cbcc07bc3f5958b4346fa18decdfc4 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 20 Sep 2022 10:47:40 +0800 Subject: [PATCH 31/67] clean code --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 87 -------------------- 1 file changed, 87 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 4634fa2e66..02d8002383 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -3,10 +3,6 @@ using namespace NPU_NAME_SPACE; using namespace std; -<<<<<<< HEAD -======= - ->>>>>>> 08f0a16 (clean code) void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { int64_t n_class = input.size(1); @@ -44,7 +40,6 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha); void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, -<<<<<<< HEAD Tensor grad_input, float gamma, float alpha) { int64_t n_class = input.size(1); @@ -149,88 +144,6 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); -======= - Tensor grad_input, float gamma, float alpha) { - - at::Tensor target_y = at::reshape(target, input.sizes()); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); - } - - OpCommand cmd; - cmd.Name("SigmoidFocalLossGrad") - .Input(input) - .Input(target_y) - .Input(grad_up) - .Input(weight_y) - .Output(grad_input) - .Attr("gamma", gamma) - .Attr("alpha", alpha) - .Attr("reduction", "none") - .Run(); -} - -void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, Tensor weight, - Tensor grad_input, float gamma, float alpha); - -void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, - Tensor output, float gamma, float alpha) { - - int64_t n_class = input.size(1); - at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); - target_y = 
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); - } - - OpCommand cmd; - cmd.Name("SoftmaxFocalLoss") - .Input(input) - .Input(target_y) - .Input(grad_up) - .Input(weight_y) - .Output(grad_input) - .Attr("gamma", gamma) - .Attr("alpha", alpha) - .Attr("reduction", "none") - .Run(); -} - -void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, - Tensor grad_input, float gamma, float alpha); - -void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff, - Tensor grad_input, float gamma, float alpha) { - - int64_t n_class = input.size(1); - at::Tensor target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); - target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); - at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if(weight_size > 0) { - weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); - } - - OpCommand cmd; - cmd.Name("SoftmaxFocalLossGrad") - .Input(input) - .Input(target_y) - .Input(grad_up) - .Input(weight_y) - .Output(grad_input) - .Attr("gamma", gamma) - .Attr("alpha", alpha) - .Attr("reduction", "none") - .Run(); ->>>>>>> 08f0a16 (clean code) } void softmax_focal_loss_backward_impl(Tensor input, Tensor target, From c5ffc604d66c24d0d86bf523b59f08b98f7240ee Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Thu, 3 Nov 2022 15:26:04 +0800 Subject: [PATCH 32/67] add modulatedDeformConv npu adapter --- mmcv/ops/modulated_deform_conv.py | 74 +++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index df5095f2e9..55893b814a 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -34,6 +34,72 @@ def symbolic(g, input, offset, mask, weight, bias, stride, padding, groups_i=groups, deform_groups_i=deform_groups) + @staticmethod + def _calculate_npu_sort_index(kernel_h, kernel_w, deformable_group): + split_num = deformable_group * 2 * kernel_h * kernel_w + sort_index_for_npu = list(range(split_num)) + sort_index_for_npu_fp = ( + sort_index_for_npu[1::2] + sort_index_for_npu[::2]) + sort_index_for_npu_bp_dict = { + i: idx + for idx, i in enumerate(sort_index_for_npu) + } + sort_index_for_npu_bp = [ + sort_index_for_npu_bp_dict[i] for i in sort_index_for_npu + ] + sort_index_for_npu_fp = torch.IntTensor(sort_index_for_npu_fp) + sort_index_for_npu_bp = torch.IntTensor(sort_index_for_npu_bp) + sort_index_for_npu_fp = sort_index_for_npu_fp.npu() + sort_index_for_npu_bp = sort_index_for_npu_bp.npu() + return sort_index_for_npu_fp, sort_index_for_npu_bp + + @staticmethod + def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): + _, _, k1, k2 = weight.shape + conv2d_bias = bias if len(bias) > 0 else None + sort_index_for_npu_fp, sort_index_for_npu_bp = \ + ModulatedDeformConv2dFunction._calculate_npu_sort_index( + k2, k1, ctx.deform_groups) + select_offset = offset.index_select(1, sort_index_for_npu_fp) + offset_all = torch.cat([select_offset, mask], dim=1) + output, offset_out = torch.npu_deformable_conv2d( + input_tensor, + weight, + offset_all, + conv2d_bias, + 
kernel_size=[k2, k1], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, + deformable_groups=ctx.deform_groups, + modulated=True) + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input_tensor.requires_grad: + ctx.save_for_backward(input_tensor, weight, offset_out, offset_all, + sort_index_for_npu_bp) + return output + + @staticmethod + def _npu_backward(ctx, grad_output): + input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ + ctx.saved_tensors + grad_input, grad_weight, grad_offset_all, grad_bias = \ + torch.npu_deformable_conv2dbk( + input_tensor, grad_output, offset_out, weight, offset_all, + kernel_size=[weight.shape[3], weight.shape[2]], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, deformable_groups=ctx.deform_groups, + modulated=True) + grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp) + grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :] + if not ctx.with_bias: + grad_bias = None + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, None, None, None, None, None, None) + @staticmethod def forward(ctx, input: torch.Tensor, @@ -56,6 +122,7 @@ def forward(ctx, ctx.groups = groups ctx.deform_groups = deform_groups ctx.with_bias = bias is not None + ctx.device = input.device.type if not ctx.with_bias: bias = input.new_empty(0) # fake tensor # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; @@ -69,6 +136,10 @@ def forward(ctx, weight = weight.type_as(input) bias = bias.type_as(input) # type: ignore mask = mask.type_as(input) + if ctx.device == 'npu': + output = ModulatedDeformConv2dFunction._npu_forward( + ctx, input, offset, mask, weight, bias) + return output ctx.save_for_backward(input, offset, mask, weight, bias) output = input.new_empty( ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) @@ -98,6 +169,9 @@ def forward(ctx, @staticmethod @once_differentiable def backward(ctx, grad_output: torch.Tensor) -> tuple: + if ctx.device == 'npu': + return ModulatedDeformConv2dFunction._npu_backward( + ctx, grad_output) input, offset, mask, weight, bias = ctx.saved_tensors grad_input = torch.zeros_like(input) grad_offset = torch.zeros_like(offset) From bd3ec115ef9d6a700b241848b3fa62e4b2876807 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Thu, 3 Nov 2022 17:13:17 +0800 Subject: [PATCH 33/67] merge master branch 20221103 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 9609b41f40..6feae1f9b0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,4 +23,4 @@ default_section = THIRDPARTY # than "BA" [codespell] quiet-level = 3 -ignore-words-list = inout,hist,ba,inh,ro,tne,warmup,warpped,warpping +ignore-words-list = inout,hist,ba,inh,ro,tne,warmup,warpped,warpping,cann From 4b9af63453333fff0e34ac6e62d2c82936412840 Mon Sep 17 00:00:00 2001 From: zcc-zjut Date: Fri, 4 Nov 2022 15:46:18 +0800 Subject: [PATCH 34/67] Add masked_ Conv2d operator in NPU --- mmcv/ops/masked_conv.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index 6706ae9b21..82f16d4a39 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -45,6 +45,22 @@ def forward(ctx, 'Stride could not only be 1 in 
masked_conv2d currently.') out_channel, in_channel, kernel_h, kernel_w = weight.size() + if features.device.type == 'npu': + import torch_npu + conv = torch_npu.npu_conv2d( + features, + weight, + bias, + stride=(stride_h, stride_w), + padding=padding, + dilation=(1, 1), + groups=1) + features_h, features_w = features.size()[2:] + mask_reshape = mask.reshape(1, 1, features_h, features_w) + mask_bool = mask_reshape > 0 + output = conv * mask_bool + return output + batch_size = features.size(0) out_h = int( math.floor( From fee53c413d49b7dd9b9c270a3c30591cf837ac3b Mon Sep 17 00:00:00 2001 From: wangxiaoxin_sherie Date: Wed, 2 Nov 2022 20:19:50 +0800 Subject: [PATCH 35/67] add nms_npu --- mmcv/ops/csrc/pytorch/npu/nms_npu.cpp | 49 +++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 mmcv/ops/csrc/pytorch/npu/nms_npu.cpp diff --git a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp new file mode 100644 index 0000000000..5b2b58cb77 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp @@ -0,0 +1,49 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset){ + c10::SmallVector boxes_size ={boxes.size(0), + .size(1)}; + at::Tensor boxed_offest = + at_npu::native::OpPreparation::ApplyTensor(boxes_size, + boxes.options().dtype(at::kFloat), boxes); + at::Tensor ones_tensor = + at_npu::native::OpPreparation::ApplyTensor(boxes_size, + boxes.options().dtype(at::kFloat), boxes).fill_(1); + at::add_out(boxed_offest, boxes, ones_tensor, offset); + c10::SmallVector OneSize = {1}; + at::Tensor iou_threshold_y = + at_npu::native::OpPreparation::ApplyTensor({}, + boxes.options().dtype(at::kFloat), boxes).fill_(iou_threshold); + at::Tensor scores_threshold_y = + at_npu::native::OpPreparation::ApplyTensor({}, + boxes.options().dtype(at::kFloat), boxes).fill_(0); + at::Tensor max_outputsize_y = + at_npu::native::OpPreparation::ApplyTensor({}, + boxes.options().dtype(at::kInt), boxes).fill_(0); + c10::SmallVector outputsize = {boxes.size(0)}; + at::Tensor output = + at_npu::native::OpPreparation::ApplyTensor(outputsize, + boxes.options().dtype(at::kInt), boxes).fill_(-1); + OpCommand cmd; + cmd.Name("NonMaxSuppressionV3") + .Input(boxes) + .Input(scores) + .Input(max_outputsize_y) + .Input(iou_threshold_y) + .Input(scores_threshold_y) + .Output(output) + .Run(); + auto outputsizeBool = at::gt(output, -1); + auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int); + auto countLen = at::sum(outputsizeInt, at::ScalarType::Int); + at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong()); + return actual_output; +} + +Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset); + +REGISTER_NPU_IMPL(nms_impl, + nms_npu); From 3a970b9b18c39401f43db8501a3b156b4be9495b Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Sat, 5 Nov 2022 11:01:29 +0800 Subject: [PATCH 36/67] fix bug --- .pre-commit-config-zh-cn.yaml | 16 ++++++++-------- mmcv/ops/csrc/pytorch/npu/nms_npu.cpp | 13 +++++-------- mmcv/ops/focal_loss.py | 2 +- setup.py | 2 +- 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index 73f0388fef..e8543d876c 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -62,11 +62,11 @@ repos: ^test | ^docs ) - # - repo: local - # hooks: - # - id: clang-format - # name: clang-format - # description: 
Format files with ClangFormat - # entry: clang-format -style=google -i - # language: system - # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ + - repo: local + hooks: + - id: clang-format + name: clang-format + description: Format files with ClangFormat + entry: clang-format -style=google -i + language: system + files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ diff --git a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp index 5b2b58cb77..0cd3087c31 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp @@ -4,16 +4,11 @@ using namespace NPU_NAME_SPACE; using namespace std; Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset){ - c10::SmallVector boxes_size ={boxes.size(0), - .size(1)}; at::Tensor boxed_offest = - at_npu::native::OpPreparation::ApplyTensor(boxes_size, - boxes.options().dtype(at::kFloat), boxes); + at_npu::native::OpPreparation::ApplyTensor(boxes); at::Tensor ones_tensor = - at_npu::native::OpPreparation::ApplyTensor(boxes_size, - boxes.options().dtype(at::kFloat), boxes).fill_(1); + at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1); at::add_out(boxed_offest, boxes, ones_tensor, offset); - c10::SmallVector OneSize = {1}; at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor({}, boxes.options().dtype(at::kFloat), boxes).fill_(iou_threshold); @@ -22,7 +17,7 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset){ boxes.options().dtype(at::kFloat), boxes).fill_(0); at::Tensor max_outputsize_y = at_npu::native::OpPreparation::ApplyTensor({}, - boxes.options().dtype(at::kInt), boxes).fill_(0); + boxes.options().dtype(at::kInt), boxes).fill_(boxes.size(0)); c10::SmallVector outputsize = {boxes.size(0)}; at::Tensor output = at_npu::native::OpPreparation::ApplyTensor(outputsize, @@ -40,6 +35,8 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset){ auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int); auto countLen = at::sum(outputsizeInt, at::ScalarType::Int); at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong()); + actual_output = + at_npu::native::NPUNativeFunctions::npu_dtype_cast(actual_output, at::kLong); return actual_output; } diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py index 491123eccf..5a941c8653 100644 --- a/mmcv/ops/focal_loss.py +++ b/mmcv/ops/focal_loss.py @@ -231,4 +231,4 @@ def __repr__(self): s += f'(gamma={self.gamma}, ' s += f'alpha={self.alpha}, ' s += f'reduction={self.reduction})' - return s \ No newline at end of file + return s diff --git a/setup.py b/setup.py index 67efcfed65..98ba6514bc 100644 --- a/setup.py +++ b/setup.py @@ -460,4 +460,4 @@ def get_extensions(): }, ext_modules=get_extensions(), cmdclass=cmd_class, - zip_safe=False) \ No newline at end of file + zip_safe=False) From 4647429aed898b1f38a4f38bef3c9306d7f65168 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Sat, 5 Nov 2022 17:00:40 +0800 Subject: [PATCH 37/67] fix code check --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 12 ++--- mmcv/ops/csrc/pytorch/npu/nms_npu.cpp | 56 ++++++++++---------- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 030fa02fb6..3e46c0cec6 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -11,8 +11,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, 
Tensor weight, target_y = at::reshape(target, input.sizes()); target_y = at::mul(target_y, -1.0); target_y = at::add(target_y, 1.0); - } - else { + }else { target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); } target_y = @@ -44,10 +43,9 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, float alpha) { int64_t n_class = input.size(1); at::Tensor target_y = at::ones_like(input); - if(n_class == 1) { + if (n_class == 1) { target_y = at::reshape(target, input.sizes()); - } - else { + }else { target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); target_y = at::mul(target_y, -1.0); target_y = at::add(target_y, 1.0); @@ -109,8 +107,8 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, c10::SmallVector sizes = {n_batch,1}; at::IntArrayRef offset = at::IntArrayRef(offsets); at::IntArrayRef size = at::IntArrayRef(sizes); - at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, - size, output); + at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size, + output); } void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, diff --git a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp index 0cd3087c31..a7b9edbb3e 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp @@ -3,27 +3,27 @@ using namespace NPU_NAME_SPACE; using namespace std; -Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset){ - at::Tensor boxed_offest = - at_npu::native::OpPreparation::ApplyTensor(boxes); - at::Tensor ones_tensor = - at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1); - at::add_out(boxed_offest, boxes, ones_tensor, offset); - at::Tensor iou_threshold_y = - at_npu::native::OpPreparation::ApplyTensor({}, - boxes.options().dtype(at::kFloat), boxes).fill_(iou_threshold); - at::Tensor scores_threshold_y = - at_npu::native::OpPreparation::ApplyTensor({}, - boxes.options().dtype(at::kFloat), boxes).fill_(0); - at::Tensor max_outputsize_y = - at_npu::native::OpPreparation::ApplyTensor({}, - boxes.options().dtype(at::kInt), boxes).fill_(boxes.size(0)); - c10::SmallVector outputsize = {boxes.size(0)}; - at::Tensor output = - at_npu::native::OpPreparation::ApplyTensor(outputsize, - boxes.options().dtype(at::kInt), boxes).fill_(-1); - OpCommand cmd; - cmd.Name("NonMaxSuppressionV3") +Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) { + at::Tensor boxed_offest = at_npu::native::OpPreparation::ApplyTensor(boxes); + at::Tensor ones_tensor = + at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1); + at::add_out(boxed_offest, boxes, ones_tensor, offset); + at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kFloat), boxes) + .fill_(iou_threshold); + at::Tensor scores_threshold_y = + at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kFloat), boxes) + .fill_(0); + at::Tensor max_outputsize_y = at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kInt), boxes) + .fill_(boxes.size(0)); + c10::SmallVector outputsize = {boxes.size(0)}; + at::Tensor output = at_npu::native::OpPreparation::ApplyTensor( + outputsize, boxes.options().dtype(at::kInt), boxes) + .fill_(-1); + OpCommand cmd; + cmd.Name("NonMaxSuppressionV3") .Input(boxes) .Input(scores) .Input(max_outputsize_y) @@ -31,13 +31,13 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, 
int offset){ .Input(scores_threshold_y) .Output(output) .Run(); - auto outputsizeBool = at::gt(output, -1); - auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int); - auto countLen = at::sum(outputsizeInt, at::ScalarType::Int); - at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong()); - actual_output = - at_npu::native::NPUNativeFunctions::npu_dtype_cast(actual_output, at::kLong); - return actual_output; + auto outputsizeBool = at::gt(output, -1); + auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int); + auto countLen = at::sum(outputsizeInt, at::ScalarType::Int); + at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong()); + actual_output = at_npu::native::NPUNativeFunctions::npu_dtype_cast( + actual_output, at::kLong); + return actual_output; } Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset); From 05d57a4b350fc2d652ebeaeda82b18c9561f6b9d Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Sat, 5 Nov 2022 17:04:24 +0800 Subject: [PATCH 38/67] fix code check --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 3e46c0cec6..c949bf9539 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -7,11 +7,11 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { int64_t n_class = input.size(1); at::Tensor target_y = at::ones_like(input); - if(n_class == 1) { + if (n_class == 1) { target_y = at::reshape(target, input.sizes()); target_y = at::mul(target_y, -1.0); target_y = at::add(target_y, 1.0); - }else { + } else { target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); } target_y = @@ -45,7 +45,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, at::Tensor target_y = at::ones_like(input); if (n_class == 1) { target_y = at::reshape(target, input.sizes()); - }else { + } else { target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); target_y = at::mul(target_y, -1.0); target_y = at::add(target_y, 1.0); @@ -103,8 +103,8 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Attr("reduction", reduction) .Run(); int64_t n_batch = input.size(0); - c10::SmallVector offsets = {0,0}; - c10::SmallVector sizes = {n_batch,1}; + c10::SmallVector offsets = {0, 0}; + c10::SmallVector sizes = {n_batch, 1}; at::IntArrayRef offset = at::IntArrayRef(offsets); at::IntArrayRef size = at::IntArrayRef(sizes); at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size, From d0702b634f3d8d36acd32284e8f4d86080b758a0 Mon Sep 17 00:00:00 2001 From: momo609 <963372609@qq.com> Date: Sat, 5 Nov 2022 17:07:04 +0800 Subject: [PATCH 39/67] fix code check --- mmcv/ops/csrc/pytorch/npu/nms_npu.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp index a7b9edbb3e..2f86893ea7 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp @@ -42,5 +42,4 @@ Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) { Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset); -REGISTER_NPU_IMPL(nms_impl, - nms_npu); +REGISTER_NPU_IMPL(nms_impl, nms_npu); From 
aa90e8ec736d00f61b636baa2c2b11e47dfef0ef Mon Sep 17 00:00:00 2001 From: zcc-zjut Date: Mon, 7 Nov 2022 10:02:04 +0800 Subject: [PATCH 40/67] Masked_conv2d NPU --- tests/test_ops/test_masked_conv2d.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_ops/test_masked_conv2d.py b/tests/test_ops/test_masked_conv2d.py index a292f6a4fd..072b2f7f6d 100644 --- a/tests/test_ops/test_masked_conv2d.py +++ b/tests/test_ops/test_masked_conv2d.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE class TestMaskedConv2d: @@ -16,7 +16,11 @@ class TestMaskedConv2d: pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_masked_conv2d_all_close(self, device): from mmcv.ops import MaskedConv2d From 8dab3a3d601a75182bca09b87f61b7083be8a56f Mon Sep 17 00:00:00 2001 From: zcc-zjut Date: Mon, 7 Nov 2022 16:29:09 +0800 Subject: [PATCH 41/67] Masked_conv2d NPU --- mmcv/ops/masked_conv.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index 82f16d4a39..a372fb9ed4 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -52,13 +52,15 @@ def forward(ctx, weight, bias, stride=(stride_h, stride_w), - padding=padding, + padding=(pad_h, pad_w), dilation=(1, 1), groups=1) - features_h, features_w = features.size()[2:] - mask_reshape = mask.reshape(1, 1, features_h, features_w) - mask_bool = mask_reshape > 0 - output = conv * mask_bool + if mask.size()[1:] != conv.size()[2:]: + raise ValueError( + 'The mask is consistent with the shape of output_conv.') + conv_h, conv_w = conv.size()[2:] + mask_reshape = mask.reshape(1, 1, conv_h, conv_w) + output = conv * mask_reshape return output batch_size = features.size(0) From 97e35bc1565f79c0c5f8fa276926319b1b8bf6dc Mon Sep 17 00:00:00 2001 From: zcc-zjut Date: Mon, 7 Nov 2022 17:48:55 +0800 Subject: [PATCH 42/67] Masked_conv2d NPU --- mmcv/ops/masked_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index a372fb9ed4..f591b17088 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -57,7 +57,7 @@ def forward(ctx, groups=1) if mask.size()[1:] != conv.size()[2:]: raise ValueError( - 'The mask is consistent with the shape of output_conv.') + 'The mask is inconsistent with the shape of output_conv.') conv_h, conv_w = conv.size()[2:] mask_reshape = mask.reshape(1, 1, conv_h, conv_w) output = conv * mask_reshape From 2fb3bbab323737a38e86975e1917945b1837f7a1 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 8 Nov 2022 19:51:46 +0800 Subject: [PATCH 43/67] remove npu-install-info in README.md --- README.md | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/README.md b/README.md index 4105786516..7e9c02bc46 100644 --- a/README.md +++ b/README.md @@ -120,39 +120,6 @@ pip install -U openmim mim install mmcv ``` -### NPU build and Installation - -You may want to run mmcv on your npu device, then you can build and install mmcv-npu by the following steps. - -a. 
Install the **ascend-toolkit** - -```python - Ascend-cann-toolkit_{version}_linux-{arch}.run -``` - -- You can download the ascend-toolkit package in https://www.hiascend.com/software/cann/community. Choose the **"Ascend-cann-toolkit\_{xxx.xxx}.run"** which fits your develop environment. -- In order to install **CANN** quickly, you can refer to the documents in https://www.hiascend.com/document/detail/zh/canncommercial/51RC2/envdeployment/instg/instg_000052.html - -b. Install the **toch_npu** - -- As the dispatch mechanism is based on torch, you have to install torch-npu before running your mmcv.ops on npu device. -- you can download the torch_npu code from https://gitee.com/ascend/pytorch, and install torch-npu as the steps in README. -- torch-npu depends on ascend-toolkit. So you have to install the ascend-toolkit, and set the ascend environment. -- ```python - source /usr/local/Ascend/ascned-toolkit/set_env.sh - ``` - -c. build and install mmcv-npu - -- ```bash - MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py build_ext - MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py develop - ``` -- or -- ```bash - MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py install - ``` - ## Branch Maintenance Plan MMCV currently has two branches, the master and 2.x branches, which go through the following three phases. From f1e825db652a337419b09c927c6016c20e5fad41 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Tue, 8 Nov 2022 20:08:15 +0800 Subject: [PATCH 44/67] annotate the clang-format in pre-commit-config-zh-ch.yaml --- .pre-commit-config-zh-cn.yaml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index e8543d876c..73f0388fef 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -62,11 +62,11 @@ repos: ^test | ^docs ) - - repo: local - hooks: - - id: clang-format - name: clang-format - description: Format files with ClangFormat - entry: clang-format -style=google -i - language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ + # - repo: local + # hooks: + # - id: clang-format + # name: clang-format + # description: Format files with ClangFormat + # entry: clang-format -style=google -i + # language: system + # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ From b73d8cb4b34c0769dd72a6f142615bbb8c58cf80 Mon Sep 17 00:00:00 2001 From: ckirchhoff2021 <515629648@qq.com> Date: Mon, 14 Nov 2022 20:41:38 +0800 Subject: [PATCH 45/67] Clean code: fix the clean code problem in masked_conv2d and modulated_deform_conv --- mmcv/ops/masked_conv.py | 8 +++--- mmcv/ops/modulated_deform_conv.py | 44 +++++++++++++------------------ 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index f591b17088..e00a98b990 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -47,7 +47,7 @@ def forward(ctx, if features.device.type == 'npu': import torch_npu - conv = torch_npu.npu_conv2d( + output = torch_npu.npu_conv2d( features, weight, bias, @@ -55,12 +55,10 @@ def forward(ctx, padding=(pad_h, pad_w), dilation=(1, 1), groups=1) - if mask.size()[1:] != conv.size()[2:]: + if mask.size()[1:] != output.size()[2:]: raise ValueError( 'The mask is inconsistent with the shape of output_conv.') - conv_h, conv_w = conv.size()[2:] - mask_reshape = mask.reshape(1, 1, conv_h, conv_w) - output = conv * mask_reshape + output = output * mask return output batch_size = features.size(0) diff --git a/mmcv/ops/modulated_deform_conv.py 
b/mmcv/ops/modulated_deform_conv.py index 55893b814a..7970d5323e 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -35,39 +35,33 @@ def symbolic(g, input, offset, mask, weight, bias, stride, padding, deform_groups_i=deform_groups) @staticmethod - def _calculate_npu_sort_index(kernel_h, kernel_w, deformable_group): + def _calculate_sort_index(kernel_h, kernel_w, deformable_group): split_num = deformable_group * 2 * kernel_h * kernel_w - sort_index_for_npu = list(range(split_num)) - sort_index_for_npu_fp = ( - sort_index_for_npu[1::2] + sort_index_for_npu[::2]) - sort_index_for_npu_bp_dict = { - i: idx - for idx, i in enumerate(sort_index_for_npu) - } - sort_index_for_npu_bp = [ - sort_index_for_npu_bp_dict[i] for i in sort_index_for_npu - ] - sort_index_for_npu_fp = torch.IntTensor(sort_index_for_npu_fp) - sort_index_for_npu_bp = torch.IntTensor(sort_index_for_npu_bp) - sort_index_for_npu_fp = sort_index_for_npu_fp.npu() - sort_index_for_npu_bp = sort_index_for_npu_bp.npu() - return sort_index_for_npu_fp, sort_index_for_npu_bp + sort_index = list(range(split_num)) + sort_index_fp = (sort_index[1::2] + sort_index[::2]) + sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)} + sort_index_bp = [sort_index_bp_dict[i] for i in sort_index] + sort_index_fp = torch.IntTensor(sort_index_fp) + sort_index_bp = torch.IntTensor(sort_index_bp) + sort_index_fp = sort_index_fp.npu() + sort_index_bp = sort_index_bp.npu() + return sort_index_fp, sort_index_bp @staticmethod def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): - _, _, k1, k2 = weight.shape + _, _, kernel_h, kernel_w = weight.shape conv2d_bias = bias if len(bias) > 0 else None - sort_index_for_npu_fp, sort_index_for_npu_bp = \ - ModulatedDeformConv2dFunction._calculate_npu_sort_index( - k2, k1, ctx.deform_groups) - select_offset = offset.index_select(1, sort_index_for_npu_fp) + sort_index_fp, sort_index_bp = \ + ModulatedDeformConv2dFunction._calculate_sort_index( + kernel_w, kernel_h, ctx.deform_groups) + select_offset = offset.index_select(1, sort_index_fp) offset_all = torch.cat([select_offset, mask], dim=1) output, offset_out = torch.npu_deformable_conv2d( input_tensor, weight, offset_all, conv2d_bias, - kernel_size=[k2, k1], + kernel_size=[kernel_w, kernel_h], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[1, 1, ctx.padding[0], ctx.padding[1]], dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], @@ -77,12 +71,12 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): if weight.requires_grad or mask.requires_grad or offset.requires_grad \ or input_tensor.requires_grad: ctx.save_for_backward(input_tensor, weight, offset_out, offset_all, - sort_index_for_npu_bp) + sort_index_bp) return output @staticmethod def _npu_backward(ctx, grad_output): - input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ + input_tensor, weight, offset_out, offset_all, sort_index_bp = \ ctx.saved_tensors grad_input, grad_weight, grad_offset_all, grad_bias = \ torch.npu_deformable_conv2dbk( @@ -93,7 +87,7 @@ def _npu_backward(ctx, grad_output): dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], groups=ctx.groups, deformable_groups=ctx.deform_groups, modulated=True) - grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp) + grad_offset = grad_offset_all.index_select(1, sort_index_bp) grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :] if not ctx.with_bias: grad_bias = None From fad6d9644f2eeca74a2b1a8aea83fc198da66d99 Mon Sep 17 00:00:00 2001 From: 
jayggh <35617559+jayggh@users.noreply.github.com> Date: Fri, 18 Nov 2022 10:07:42 +0800 Subject: [PATCH 46/67] Create fused_bias_leakyrelu_npu.cpp Add NPU adapter for fused_bias_leaky_relu operator --- .../pytorch/npu/fused_bias_leakyrelu_npu.cpp | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp new file mode 100644 index 0000000000..16aaf39945 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -0,0 +1,54 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +Tensor fused_bias_leakyrelu_op_impl(Tensor& input, Tensor& bias, + Tensor& refer, int act, int grad, + float alpha, float scale) { +} + +Tensor fused_bias_leakyrelu_npu(Tensor& input,Tensor& bias, + Tensor& refer, int act, int grad, + float alpha, float scale) { + at::tensor y = at::tempty_like(input); + // forward + if (grad == 1){ + auto input_size = input.size(); + int input_length = input_size.size(); + if (input_length > 1){ + for (int i = 0; i < input_length; i++){ + if (i != 1){ + input_size[i] = 1; + } + } + } + at::Tensor bias_ = at::reshape(bias, input_size); + at::Tensor bias_tmp = NPUNativeFunctions::npu_broadcast(bias_, input.size()); + OpCommand cmd; + cmd.Name("FusedBiasLeakyRelu") + .Input(input); + .Input(bias); + .Output(y); + .Attr("sacle",sacle); + .Attr("negative_slope", alpha); + .Run(); + } + + // backward + if (grad == 2){ + OpCommand cmd; + cmd.Name("FusedBiasLeakyReluGrad") + .Input(input); + .Input(ref); + .Output(y); + .Attr("sacle",sacle); + .Attr("negative_slope", alpha); + .Run(); + } + + return y; +} + +REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, + fused_bias_leakyrelu_npu); From 5c59540f5bfc6b9fe172a099e96c436553e1c40e Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Wed, 7 Dec 2022 12:53:33 +0800 Subject: [PATCH 47/67] Update fused_bias_leakyrelu_npu.cpp --- .../pytorch/npu/fused_bias_leakyrelu_npu.cpp | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index 16aaf39945..36b3c543be 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -3,51 +3,52 @@ using namespace NPU_NAME_SPACE; using namespace std; -Tensor fused_bias_leakyrelu_op_impl(Tensor& input, Tensor& bias, - Tensor& refer, int act, int grad, - float alpha, float scale) { -} +Tensor fused_bias_leakyrelu_op_impl(Tensor& input, Tensor& bias, + Tensor& refer, int act, + int grad, float alpha, float scale); -Tensor fused_bias_leakyrelu_npu(Tensor& input,Tensor& bias, - Tensor& refer, int act, int grad, - float alpha, float scale) { - at::tensor y = at::tempty_like(input); - // forward +Tensor fused_bias_leakyrelu_npu(Tensor& input, Tensor& bias, + Tensor& refer, int act, + int grad, float alpha, float scale){ + at::Tensor py = at::empty_like(input); + //forward if (grad == 1){ - auto input_size = input.size(); + auto input_size = input.sizes(); int input_length = input_size.size(); + c10::SmallVector input_size_tmp; + input_size_tmp = array_to_small_vector(input_size); if (input_length > 1){ for (int i = 0; i < input_length; i++){ if (i != 1){ - input_size[i] = 1; - } + input_size_tmp[i] = 1; + } } } - 
at::Tensor bias_ = at::reshape(bias, input_size); - at::Tensor bias_tmp = NPUNativeFunctions::npu_broadcast(bias_, input.size()); + at::Tensor bias_tmp = at::reshape(bias, input_size_tmp); + at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(bias_tmp, input.sizes()); OpCommand cmd; cmd.Name("FusedBiasLeakyRelu") - .Input(input); - .Input(bias); - .Output(y); - .Attr("sacle",sacle); - .Attr("negative_slope", alpha); + .Input(input) + .Input(bias_) + .Output(py) + .Attr("scale", scale) + .Attr("negative_slope", alpha) .Run(); } - // backward + //backward if (grad == 2){ OpCommand cmd; cmd.Name("FusedBiasLeakyReluGrad") - .Input(input); - .Input(ref); - .Output(y); - .Attr("sacle",sacle); - .Attr("negative_slope", alpha); + .Input(input) + .Input(refer) + .Output(py) + .Attr("scale", scale) + .Attr("negative_slope", alpha) .Run(); } - - return y; + return py; + } REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, From 9c2c2d8d5bae9983f3171c2b568fed6b28e5e8ac Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Wed, 7 Dec 2022 13:03:03 +0800 Subject: [PATCH 48/67] Update fused_bias_leakyrelu_npu.cpp --- mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index 36b3c543be..1c3133c7e1 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -3,12 +3,12 @@ using namespace NPU_NAME_SPACE; using namespace std; -Tensor fused_bias_leakyrelu_op_impl(Tensor& input, Tensor& bias, - Tensor& refer, int act, +Tensor fused_bias_leakyrelu_op_impl(Tensor& input, Tensor& bias, + Tensor& refer, int act, int grad, float alpha, float scale); -Tensor fused_bias_leakyrelu_npu(Tensor& input, Tensor& bias, - Tensor& refer, int act, +Tensor fused_bias_leakyrelu_npu(Tensor& input, Tensor& bias, + Tensor& refer, int act, int grad, float alpha, float scale){ at::Tensor py = at::empty_like(input); //forward From 4755f972209c58db7f4adcf7c767265cdeabe105 Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Wed, 7 Dec 2022 14:40:10 +0800 Subject: [PATCH 49/67] Update ops.md --- docs/zh_cn/understand_mmcv/ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 23d9b6e5fd..dda50c8ad3 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -24,7 +24,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | DynamicScatter | | √ | | | | | FurthestPointSample | | √ | | | | | FurthestPointSampleWithDist | | √ | | | | -| FusedBiasLeakyrelu | | √ | | | | +| FusedBiasLeakyrelu | | √ | | | √ | | GatherPoints | | √ | | | | | GroupPoints | | √ | | | | | Iou3d | | √ | √ | | | From 2176a308efc3d4e4985e5450853d358d6c207cf8 Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Wed, 7 Dec 2022 14:40:40 +0800 Subject: [PATCH 50/67] Update ops.md --- docs/en/understand_mmcv/ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 822ee15589..d83f2c03be 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -24,7 +24,7 @@ We implement common ops used in detection, segmentation, etc. 
| DynamicScatter | | √ | | | | | FurthestPointSample | | √ | | | | | FurthestPointSampleWithDist | | √ | | | | -| FusedBiasLeakyrelu | | √ | | | | +| FusedBiasLeakyrelu | | √ | | | √ | | GatherPoints | | √ | | | | | GroupPoints | | √ | | | | | Iou3d | | √ | √ | | | From e1e1f08d3cba85fbdfb07521b7dd392c1c3996e0 Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Wed, 7 Dec 2022 15:48:33 +0800 Subject: [PATCH 51/67] Update fused_bias_leakyrelu_npu.cpp --- mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index 1c3133c7e1..eab0cfc4b4 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -12,7 +12,7 @@ Tensor fused_bias_leakyrelu_npu(Tensor& input, Tensor& bias, int grad, float alpha, float scale){ at::Tensor py = at::empty_like(input); //forward - if (grad == 1){ + if (grad == 0){ auto input_size = input.sizes(); int input_length = input_size.size(); c10::SmallVector input_size_tmp; @@ -37,7 +37,7 @@ Tensor fused_bias_leakyrelu_npu(Tensor& input, Tensor& bias, } //backward - if (grad == 2){ + if (grad == 1){ OpCommand cmd; cmd.Name("FusedBiasLeakyReluGrad") .Input(input) From 0a0f0895ff3c6ffef12a77df3488dbc995dd169b Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Wed, 7 Dec 2022 20:17:09 +0800 Subject: [PATCH 52/67] Update fused_bias_leakyrelu_npu.cpp --- mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index eab0cfc4b4..e319030221 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -3,12 +3,12 @@ using namespace NPU_NAME_SPACE; using namespace std; -Tensor fused_bias_leakyrelu_op_impl(Tensor& input, Tensor& bias, - Tensor& refer, int act, +Tensor fused_bias_leakyrelu_op_impl(const Tensor& input, const Tensor& bias, + const Tensor& refer, int act, int grad, float alpha, float scale); -Tensor fused_bias_leakyrelu_npu(Tensor& input, Tensor& bias, - Tensor& refer, int act, +Tensor fused_bias_leakyrelu_npu(const Tensor& input, const Tensor& bias, + const Tensor& refer, int act, int grad, float alpha, float scale){ at::Tensor py = at::empty_like(input); //forward From 0f382c20ad8f0ab6105f207112bdf60bc0bd3f8d Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Thu, 8 Dec 2022 09:30:37 +0800 Subject: [PATCH 53/67] Update test_fused_bias_leakyrelu.py --- tests/test_ops/test_fused_bias_leakyrelu.py | 80 ++++++++++++++++----- 1 file changed, 61 insertions(+), 19 deletions(-) diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 47357860de..73b3e057a4 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import pytest import torch +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE _USING_PARROTS = True try: @@ -14,36 +15,77 @@ class TestFusedBiasLeakyReLU: @classmethod def setup_class(cls): - if not torch.cuda.is_available(): + if not IS_CUDA_AVAILABLE and not IS_NPU_AVAILABLE: return - cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda() - cls.bias = torch.zeros(2, requires_grad=True).cuda() + if IS_CUDA_AVAILABLE: + cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda() + cls.bias = torch.zeros(2, requires_grad=True).cuda() + else: + cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).npu() + cls.bias = torch.zeros(2, requires_grad=True).npu() - @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) + ]) def test_gradient(self): from mmcv.ops import FusedBiasLeakyReLU if _USING_PARROTS: + if IS_CUDA_AVAILABLE: + gradcheck( + FusedBiasLeakyReLU(2).cuda(), + self.input_tensor, + delta=1e-4, + pt_atol=1e-3) + else: + gradcheck( + FusedBiasLeakyReLU(2).npu(), + self.input_tensor, + delta=1e-4, + pt_atol=1e-3) + else: + if IS_CUDA_AVAILABLE: + gradcheck( + FusedBiasLeakyReLU(2).cuda(), + self.input_tensor, + eps=1e-4, + atol=1e-3) + else: + gradcheck( + FusedBiasLeakyReLU(2).npu(), + self.input_tensor, + eps=1e-4, + atol=1e-3) + + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) + ]) + def test_gradgradient(self): + + from mmcv.ops import FusedBiasLeakyReLU + if IS_CUDA_AVAILABLE: gradcheck( FusedBiasLeakyReLU(2).cuda(), self.input_tensor, - delta=1e-4, - pt_atol=1e-3) + eps=1e-4, + atol=1e-3) else: gradcheck( - FusedBiasLeakyReLU(2).cuda(), + FusedBiasLeakyReLU(2).npu(), self.input_tensor, eps=1e-4, atol=1e-3) - - @pytest.mark.skipif( - not torch.cuda.is_available() or _USING_PARROTS, - reason='requires cuda') - def test_gradgradient(self): - - from mmcv.ops import FusedBiasLeakyReLU - gradgradcheck( - FusedBiasLeakyReLU(2).cuda(), - self.input_tensor, - eps=1e-4, - atol=1e-3) From 5ea830da4c72a8e100c87dd0ce9856c2ecdc5ae9 Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Thu, 8 Dec 2022 09:35:25 +0800 Subject: [PATCH 54/67] Update fused_bias_leakyrelu.py --- mmcv/ops/fused_bias_leakyrelu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/fused_bias_leakyrelu.py b/mmcv/ops/fused_bias_leakyrelu.py index e23617fb3a..e562225cce 100644 --- a/mmcv/ops/fused_bias_leakyrelu.py +++ b/mmcv/ops/fused_bias_leakyrelu.py @@ -258,7 +258,7 @@ def fused_bias_leakyrelu(input: torch.Tensor, torch.Tensor: Feature map after non-linear activation. 
""" - if not input.is_cuda: + if not input.is_cuda and input.device.type != "npu": return bias_leakyrelu_ref(input, bias, negative_slope, scale) return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype), From faf4f0c78d90e52fe7eceed0b8c7ac8d3a7bf22f Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Thu, 8 Dec 2022 09:51:28 +0800 Subject: [PATCH 55/67] Update test_fused_bias_leakyrelu.py --- tests/test_ops/test_fused_bias_leakyrelu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 73b3e057a4..af2dbeee4c 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -34,7 +34,7 @@ def setup_class(cls): marks=pytest.mark.skipif( not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) - def test_gradient(self): + def test_gradient(self, device): from mmcv.ops import FusedBiasLeakyReLU if _USING_PARROTS: @@ -74,7 +74,7 @@ def test_gradient(self): marks=pytest.mark.skipif( not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) - def test_gradgradient(self): + def test_gradgradient(self, device): from mmcv.ops import FusedBiasLeakyReLU if IS_CUDA_AVAILABLE: From 27a30609d7132e88bf8dd3cf2122f96bb0a80975 Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Thu, 8 Dec 2022 10:26:09 +0800 Subject: [PATCH 56/67] Update fused_bias_leakyrelu.py --- mmcv/ops/fused_bias_leakyrelu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/fused_bias_leakyrelu.py b/mmcv/ops/fused_bias_leakyrelu.py index e562225cce..fe17d2db7b 100644 --- a/mmcv/ops/fused_bias_leakyrelu.py +++ b/mmcv/ops/fused_bias_leakyrelu.py @@ -258,7 +258,7 @@ def fused_bias_leakyrelu(input: torch.Tensor, torch.Tensor: Feature map after non-linear activation. 
""" - if not input.is_cuda and input.device.type != "npu": + if not input.is_cuda and input.device.type != 'npu': return bias_leakyrelu_ref(input, bias, negative_slope, scale) return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype), From 0ec0363a75b2d14a74891addb113283f91b36bf5 Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Thu, 8 Dec 2022 10:39:37 +0800 Subject: [PATCH 57/67] Update test_fused_bias_leakyrelu.py --- tests/test_ops/test_fused_bias_leakyrelu.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index af2dbeee4c..21ce3b072e 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -7,7 +7,7 @@ try: from parrots.autograd import gradcheck except ImportError: - from torch.autograd import gradcheck, gradgradcheck + from torch.autograd import gradcheck _USING_PARROTS = False @@ -18,10 +18,12 @@ def setup_class(cls): if not IS_CUDA_AVAILABLE and not IS_NPU_AVAILABLE: return if IS_CUDA_AVAILABLE: - cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda() + cls.input_tensor = torch.randn((2, 2, 2, 2), + requires_grad=True).cuda() cls.bias = torch.zeros(2, requires_grad=True).cuda() else: - cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).npu() + cls.input_tensor = torch.randn((2, 2, 2, 2), + requires_grad=True).npu() cls.bias = torch.zeros(2, requires_grad=True).npu() @pytest.mark.parametrize('device', [ From 0869a6020cce4cafc0103d58367f4ca66555d73c Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Thu, 8 Dec 2022 10:40:54 +0800 Subject: [PATCH 58/67] Update ops.md --- docs/en/understand_mmcv/ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 773f681bd1..d83f2c03be 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -59,4 +59,4 @@ We implement common ops used in detection, segmentation, etc. 
| TINShift | | √ | √ | | | | UpFirDn2d | | √ | | | | | Voxelization | √ | √ | | | | -| PrRoIPool | | √ | | | | \ No newline at end of file +| PrRoIPool | | √ | | | | From 5477510212ad1b7c8d6abfd6417ac0795a16e010 Mon Sep 17 00:00:00 2001 From: jayggh <1439725485@qq.com> Date: Thu, 8 Dec 2022 20:29:51 +0800 Subject: [PATCH 59/67] amend for CI --- .../pytorch/npu/fused_bias_leakyrelu_npu.cpp | 43 +++++++++++-------- tests/test_ops/test_fused_bias_leakyrelu.py | 1 + 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index e319030221..8c9ac2f869 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -3,29 +3,35 @@ using namespace NPU_NAME_SPACE; using namespace std; -Tensor fused_bias_leakyrelu_op_impl(const Tensor& input, const Tensor& bias, - const Tensor& refer, int act, - int grad, float alpha, float scale); +Tensor fused_bias_leakyrelu_op_impl(const Tensor &input, const Tensor &bias, + const Tensor &refer, int act, int grad, + float alpha, float scale); -Tensor fused_bias_leakyrelu_npu(const Tensor& input, const Tensor& bias, - const Tensor& refer, int act, - int grad, float alpha, float scale){ +Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, + const Tensor &refer, int act, int grad, + float alpha, float scale) +{ at::Tensor py = at::empty_like(input); - //forward - if (grad == 0){ + // forward + if (grad == 0) + { auto input_size = input.sizes(); int input_length = input_size.size(); c10::SmallVector input_size_tmp; input_size_tmp = array_to_small_vector(input_size); - if (input_length > 1){ - for (int i = 0; i < input_length; i++){ - if (i != 1){ - input_size_tmp[i] = 1; - } + if (input_length > 1) + { + for (int i = 0; i < input_length; i++) + { + if (i != 1) + { + input_size_tmp[i] = 1; } + } } at::Tensor bias_tmp = at::reshape(bias, input_size_tmp); - at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(bias_tmp, input.sizes()); + at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast( + bias_tmp, input.sizes()); OpCommand cmd; cmd.Name("FusedBiasLeakyRelu") .Input(input) @@ -36,8 +42,9 @@ Tensor fused_bias_leakyrelu_npu(const Tensor& input, const Tensor& bias, .Run(); } - //backward - if (grad == 1){ + // backward + if (grad == 1) + { OpCommand cmd; cmd.Name("FusedBiasLeakyReluGrad") .Input(input) @@ -48,8 +55,6 @@ Tensor fused_bias_leakyrelu_npu(const Tensor& input, const Tensor& bias, .Run(); } return py; - } -REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, - fused_bias_leakyrelu_npu); +REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, fused_bias_leakyrelu_npu); diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 21ce3b072e..75b233042e 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import pytest import torch + from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE _USING_PARROTS = True From 15ac0ebef4d75ed613721e08078befbae2d2c3a5 Mon Sep 17 00:00:00 2001 From: jayggh <1439725485@qq.com> Date: Thu, 8 Dec 2022 21:11:33 +0800 Subject: [PATCH 60/67] bugfix --- tests/test_ops/test_fused_bias_leakyrelu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 75b233042e..98129df9c2 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -22,7 +22,7 @@ def setup_class(cls): cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda() cls.bias = torch.zeros(2, requires_grad=True).cuda() - else: + elif IS_NPU_AVAILABLE: cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).npu() cls.bias = torch.zeros(2, requires_grad=True).npu() @@ -47,7 +47,7 @@ def test_gradient(self, device): self.input_tensor, delta=1e-4, pt_atol=1e-3) - else: + elif IS_NPU_AVAILABLE: gradcheck( FusedBiasLeakyReLU(2).npu(), self.input_tensor, @@ -60,7 +60,7 @@ def test_gradient(self, device): self.input_tensor, eps=1e-4, atol=1e-3) - else: + elif IS_NPU_AVAILABLE: gradcheck( FusedBiasLeakyReLU(2).npu(), self.input_tensor, @@ -86,7 +86,7 @@ def test_gradgradient(self, device): self.input_tensor, eps=1e-4, atol=1e-3) - else: + elif IS_NPU_AVAILABLE: gradcheck( FusedBiasLeakyReLU(2).npu(), self.input_tensor, From 780e3a16753e259e02de86b5c7bf3427d49f8390 Mon Sep 17 00:00:00 2001 From: jayggh <1439725485@qq.com> Date: Fri, 9 Dec 2022 09:51:33 +0800 Subject: [PATCH 61/67] amend ops.md --- docs/en/understand_mmcv/ops.md | 116 +++++++++++++++--------------- docs/zh_cn/understand_mmcv/ops.md | 116 +++++++++++++++--------------- 2 files changed, 116 insertions(+), 116 deletions(-) diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index d83f2c03be..cfc70e7734 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -2,61 +2,61 @@ We implement common ops used in detection, segmentation, etc. 
-| Device | CPU | CUDA | MLU | MPS | NPU | -| ---------------------------- | --- | ---- | --- | --- | --- | -| ActiveRotatedFilter | √ | √ | | | | -| AssignScoreWithK | | √ | | | | -| BallQuery | | √ | | | | -| BBoxOverlaps | | √ | √ | √ | | -| BorderAlign | | √ | | | | -| BoxIouRotated | √ | √ | | | | -| BoxIouQuadri | √ | √ | | | | -| CARAFE | | √ | √ | | | -| ChamferDistance | | √ | | | | -| CrissCrossAttention | | √ | | | | -| ContourExpand | √ | | | | | -| ConvexIoU | | √ | | | | -| CornerPool | | √ | | | | -| Correlation | | √ | | | | -| Deformable Convolution v1/v2 | √ | √ | | | | -| Deformable RoIPool | | √ | √ | | | -| DiffIoURotated | | √ | | | | -| DynamicScatter | | √ | | | | -| FurthestPointSample | | √ | | | | -| FurthestPointSampleWithDist | | √ | | | | -| FusedBiasLeakyrelu | | √ | | | √ | -| GatherPoints | | √ | | | | -| GroupPoints | | √ | | | | -| Iou3d | | √ | √ | | | -| KNN | | √ | | | | -| MaskedConv | | √ | √ | | √ | -| MergeCells | | √ | | | | -| MinAreaPolygon | | √ | | | | -| ModulatedDeformConv2d | √ | √ | √ | | √ | -| MultiScaleDeformableAttn | | √ | √ | | | -| NMS | √ | √ | √ | | √ | -| NMSRotated | √ | √ | | | | -| NMSQuadri | √ | √ | | | | -| PixelGroup | √ | | | | | -| PointsInBoxes | √ | √ | | | | -| PointsInPolygons | | √ | | | | -| PSAMask | √ | √ | √ | | | -| RotatedFeatureAlign | √ | √ | | | | -| RoIPointPool3d | | √ | √ | | | -| RoIPool | | √ | √ | | | -| RoIAlignRotated | √ | √ | √ | | | -| RiRoIAlignRotated | | √ | | | | -| RoIAlign | √ | √ | √ | | | -| RoIAwarePool3d | | √ | √ | | | -| SAConv2d | | √ | | | | -| SigmoidFocalLoss | | √ | √ | | √ | -| SoftmaxFocalLoss | | √ | | | √ | -| SoftNMS | | √ | | | | -| Sparse Convolution | | √ | | | | -| Synchronized BatchNorm | | √ | | | | -| ThreeInterpolate | | √ | | | | -| ThreeNN | | √ | √ | | | -| TINShift | | √ | √ | | | -| UpFirDn2d | | √ | | | | -| Voxelization | √ | √ | | | | -| PrRoIPool | | √ | | | | +| Device | CPU | CUDA | MLU | MPS | Ascend | +| ---------------------------- | --- | ---- | --- | --- | ------ | +| ActiveRotatedFilter | √ | √ | | | | +| AssignScoreWithK | | √ | | | | +| BallQuery | | √ | | | | +| BBoxOverlaps | | √ | √ | √ | | +| BorderAlign | | √ | | | | +| BoxIouRotated | √ | √ | | | | +| BoxIouQuadri | √ | √ | | | | +| CARAFE | | √ | √ | | | +| ChamferDistance | | √ | | | | +| CrissCrossAttention | | √ | | | | +| ContourExpand | √ | | | | | +| ConvexIoU | | √ | | | | +| CornerPool | | √ | | | | +| Correlation | | √ | | | | +| Deformable Convolution v1/v2 | √ | √ | | | | +| Deformable RoIPool | | √ | √ | | | +| DiffIoURotated | | √ | | | | +| DynamicScatter | | √ | | | | +| FurthestPointSample | | √ | | | | +| FurthestPointSampleWithDist | | √ | | | | +| FusedBiasLeakyrelu | | √ | | | √ | +| GatherPoints | | √ | | | | +| GroupPoints | | √ | | | | +| Iou3d | | √ | √ | | | +| KNN | | √ | | | | +| MaskedConv | | √ | √ | | √ | +| MergeCells | | √ | | | | +| MinAreaPolygon | | √ | | | | +| ModulatedDeformConv2d | √ | √ | √ | | √ | +| MultiScaleDeformableAttn | | √ | √ | | | +| NMS | √ | √ | √ | | √ | +| NMSRotated | √ | √ | | | | +| NMSQuadri | √ | √ | | | | +| PixelGroup | √ | | | | | +| PointsInBoxes | √ | √ | | | | +| PointsInPolygons | | √ | | | | +| PSAMask | √ | √ | √ | | | +| RotatedFeatureAlign | √ | √ | | | | +| RoIPointPool3d | | √ | √ | | | +| RoIPool | | √ | √ | | | +| RoIAlignRotated | √ | √ | √ | | | +| RiRoIAlignRotated | | √ | | | | +| RoIAlign | √ | √ | √ | | | +| RoIAwarePool3d | | √ | √ | | | +| SAConv2d | | √ | | | | +| SigmoidFocalLoss | | √ | √ | | √ | +| 
SoftmaxFocalLoss | | √ | | | √ | +| SoftNMS | | √ | | | | +| Sparse Convolution | | √ | | | | +| Synchronized BatchNorm | | √ | | | | +| ThreeInterpolate | | √ | | | | +| ThreeNN | | √ | √ | | | +| TINShift | | √ | √ | | | +| UpFirDn2d | | √ | | | | +| Voxelization | √ | √ | | | | +| PrRoIPool | | √ | | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index dda50c8ad3..7fbba87689 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -2,61 +2,61 @@ MMCV 提供了检测、分割等任务中常用的算子 -| Device | CPU | CUDA | MLU | MPS | NPU | -| ---------------------------- | --- | ---- | --- | --- | --- | -| ActiveRotatedFilter | √ | √ | | | | -| AssignScoreWithK | | √ | | | | -| BallQuery | | √ | | | | -| BBoxOverlaps | | √ | √ | √ | | -| BorderAlign | | √ | | | | -| BoxIouRotated | √ | √ | | | | -| BoxIouQuadri | √ | √ | | | | -| CARAFE | | √ | √ | | | -| ChamferDistance | | √ | | | | -| CrissCrossAttention | | √ | | | | -| ContourExpand | √ | | | | | -| ConvexIoU | | √ | | | | -| CornerPool | | √ | | | | -| Correlation | | √ | | | | -| Deformable Convolution v1/v2 | √ | √ | | | | -| Deformable RoIPool | | √ | √ | | | -| DiffIoURotated | | √ | | | | -| DynamicScatter | | √ | | | | -| FurthestPointSample | | √ | | | | -| FurthestPointSampleWithDist | | √ | | | | -| FusedBiasLeakyrelu | | √ | | | √ | -| GatherPoints | | √ | | | | -| GroupPoints | | √ | | | | -| Iou3d | | √ | √ | | | -| KNN | | √ | | | | -| MaskedConv | | √ | √ | | √ | -| MergeCells | | √ | | | | -| MinAreaPolygon | | √ | | | | -| ModulatedDeformConv2d | √ | √ | √ | | √ | -| MultiScaleDeformableAttn | | √ | √ | | | -| NMS | √ | √ | √ | | √ | -| NMSRotated | √ | √ | | | | -| NMSQuadri | √ | √ | | | | -| PixelGroup | √ | | | | | -| PointsInBoxes | √ | √ | | | | -| PointsInPolygons | | √ | | | | -| PSAMask | √ | √ | √ | | | -| RotatedFeatureAlign | √ | √ | | | | -| RoIPointPool3d | | √ | √ | | | -| RoIPool | | √ | √ | | | -| RoIAlignRotated | √ | √ | √ | | | -| RiRoIAlignRotated | | √ | | | | -| RoIAlign | √ | √ | √ | | | -| RoIAwarePool3d | | √ | √ | | | -| SAConv2d | | √ | | | | -| SigmoidFocalLoss | | √ | √ | | √ | -| SoftmaxFocalLoss | | √ | | | √ | -| SoftNMS | | √ | | | | -| Sparse Convolution | | √ | | | | -| Synchronized BatchNorm | | √ | | | | -| ThreeInterpolate | | √ | | | | -| ThreeNN | | √ | √ | | | -| TINShift | | √ | √ | | | -| UpFirDn2d | | √ | | | | -| Voxelization | √ | √ | | | | -| PrRoIPool | | √ | | | | +| Device | CPU | CUDA | MLU | MPS | Ascend | +| ---------------------------- | --- | ---- | --- | --- | ------ | +| ActiveRotatedFilter | √ | √ | | | | +| AssignScoreWithK | | √ | | | | +| BallQuery | | √ | | | | +| BBoxOverlaps | | √ | √ | √ | | +| BorderAlign | | √ | | | | +| BoxIouRotated | √ | √ | | | | +| BoxIouQuadri | √ | √ | | | | +| CARAFE | | √ | √ | | | +| ChamferDistance | | √ | | | | +| CrissCrossAttention | | √ | | | | +| ContourExpand | √ | | | | | +| ConvexIoU | | √ | | | | +| CornerPool | | √ | | | | +| Correlation | | √ | | | | +| Deformable Convolution v1/v2 | √ | √ | | | | +| Deformable RoIPool | | √ | √ | | | +| DiffIoURotated | | √ | | | | +| DynamicScatter | | √ | | | | +| FurthestPointSample | | √ | | | | +| FurthestPointSampleWithDist | | √ | | | | +| FusedBiasLeakyrelu | | √ | | | √ | +| GatherPoints | | √ | | | | +| GroupPoints | | √ | | | | +| Iou3d | | √ | √ | | | +| KNN | | √ | | | | +| MaskedConv | | √ | √ | | √ | +| MergeCells | | √ | | | | +| MinAreaPolygon | | √ | | | | +| ModulatedDeformConv2d | √ | √ | √ | | √ | +| 
MultiScaleDeformableAttn | | √ | √ | | | +| NMS | √ | √ | √ | | √ | +| NMSRotated | √ | √ | | | | +| NMSQuadri | √ | √ | | | | +| PixelGroup | √ | | | | | +| PointsInBoxes | √ | √ | | | | +| PointsInPolygons | | √ | | | | +| PSAMask | √ | √ | √ | | | +| RotatedFeatureAlign | √ | √ | | | | +| RoIPointPool3d | | √ | √ | | | +| RoIPool | | √ | √ | | | +| RoIAlignRotated | √ | √ | √ | | | +| RiRoIAlignRotated | | √ | | | | +| RoIAlign | √ | √ | √ | | | +| RoIAwarePool3d | | √ | √ | | | +| SAConv2d | | √ | | | | +| SigmoidFocalLoss | | √ | √ | | √ | +| SoftmaxFocalLoss | | √ | | | √ | +| SoftNMS | | √ | | | | +| Sparse Convolution | | √ | | | | +| Synchronized BatchNorm | | √ | | | | +| ThreeInterpolate | | √ | | | | +| ThreeNN | | √ | √ | | | +| TINShift | | √ | √ | | | +| UpFirDn2d | | √ | | | | +| Voxelization | √ | √ | | | | +| PrRoIPool | | √ | | | | From 7bb7049b1e48301082bc2b734cfc6a3c979bdbe3 Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Mon, 12 Dec 2022 08:47:16 +0800 Subject: [PATCH 62/67] Update test_fused_bias_leakyrelu.py --- tests/test_ops/test_fused_bias_leakyrelu.py | 34 +++++---------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 98129df9c2..076fd512a5 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -47,25 +47,12 @@ def test_gradient(self, device): self.input_tensor, delta=1e-4, pt_atol=1e-3) - elif IS_NPU_AVAILABLE: - gradcheck( - FusedBiasLeakyReLU(2).npu(), - self.input_tensor, - delta=1e-4, - pt_atol=1e-3) else: - if IS_CUDA_AVAILABLE: - gradcheck( - FusedBiasLeakyReLU(2).cuda(), - self.input_tensor, - eps=1e-4, - atol=1e-3) - elif IS_NPU_AVAILABLE: - gradcheck( - FusedBiasLeakyReLU(2).npu(), - self.input_tensor, - eps=1e-4, - atol=1e-3) + gradcheck( + FusedBiasLeakyReLU(2).to(device), + self.input_tensor, + eps=1e-4, + atol=1e-3) @pytest.mark.parametrize('device', [ pytest.param( @@ -80,15 +67,8 @@ def test_gradient(self, device): def test_gradgradient(self, device): from mmcv.ops import FusedBiasLeakyReLU - if IS_CUDA_AVAILABLE: - gradcheck( - FusedBiasLeakyReLU(2).cuda(), - self.input_tensor, - eps=1e-4, - atol=1e-3) - elif IS_NPU_AVAILABLE: - gradcheck( - FusedBiasLeakyReLU(2).npu(), + gradcheck( + FusedBiasLeakyReLU(2).to(device), self.input_tensor, eps=1e-4, atol=1e-3) From f60c9164307e48617593d5cda6de4974614df284 Mon Sep 17 00:00:00 2001 From: jayggh <1439725485@qq.com> Date: Mon, 12 Dec 2022 11:44:11 +0800 Subject: [PATCH 63/67] clean code --- tests/test_ops/test_fused_bias_leakyrelu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 076fd512a5..5b963d718c 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -68,7 +68,7 @@ def test_gradgradient(self, device): from mmcv.ops import FusedBiasLeakyReLU gradcheck( - FusedBiasLeakyReLU(2).to(device), - self.input_tensor, - eps=1e-4, - atol=1e-3) + FusedBiasLeakyReLU(2).to(device), + self.input_tensor, + eps=1e-4, + atol=1e-3) From 37b0d85468d4addc5bf6a44f76641aedbcd0c826 Mon Sep 17 00:00:00 2001 From: jayggh <1439725485@qq.com> Date: Mon, 12 Dec 2022 14:33:23 +0800 Subject: [PATCH 64/67] bugfix --- tests/test_ops/test_fused_bias_leakyrelu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 5b963d718c..e6f6fb9f75 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -8,7 +8,7 @@ try: from parrots.autograd import gradcheck except ImportError: - from torch.autograd import gradcheck + from torch.autograd import gradcheck, gradgradcheck _USING_PARROTS = False @@ -67,7 +67,7 @@ def test_gradient(self, device): def test_gradgradient(self, device): from mmcv.ops import FusedBiasLeakyReLU - gradcheck( + gradgradcheck( FusedBiasLeakyReLU(2).to(device), self.input_tensor, eps=1e-4, From 57754775586758d337ebb26483b1f5c187f41fd8 Mon Sep 17 00:00:00 2001 From: jayggh <1439725485@qq.com> Date: Mon, 12 Dec 2022 15:38:05 +0800 Subject: [PATCH 65/67] clean code --- .../pytorch/npu/fused_bias_leakyrelu_npu.cpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index 8c9ac2f869..651518be68 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -9,22 +9,17 @@ Tensor fused_bias_leakyrelu_op_impl(const Tensor &input, const Tensor &bias, Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, const Tensor &refer, int act, int grad, - float alpha, float scale) -{ + float alpha, float scale){ at::Tensor py = at::empty_like(input); // forward - if (grad == 0) - { + if (grad == 0){ auto input_size = input.sizes(); int input_length = input_size.size(); c10::SmallVector input_size_tmp; input_size_tmp = array_to_small_vector(input_size); - if (input_length > 1) - { - for (int i = 0; i < input_length; i++) - { - if (i != 1) - { + if (input_length > 1){ + for (int i = 0; i < input_length; i++){ + if (i != 1){ input_size_tmp[i] = 1; } } @@ -43,8 +38,7 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, } // backward - if (grad == 1) - { + if (grad == 1){ OpCommand cmd; cmd.Name("FusedBiasLeakyReluGrad") .Input(input) From 2b33fa15238679cd5453f92d4959c8078297d29c Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Mon, 12 Dec 2022 19:50:24 +0800 Subject: [PATCH 66/67] Update fused_bias_leakyrelu_npu.cpp --- mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index 651518be68..cc4f4b37c0 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -12,14 +12,14 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, float alpha, float scale){ at::Tensor py = at::empty_like(input); // forward - if (grad == 0){ + if (grad == 0) { auto input_size = input.sizes(); int input_length = input_size.size(); c10::SmallVector input_size_tmp; input_size_tmp = array_to_small_vector(input_size); - if (input_length > 1){ - for (int i = 0; i < input_length; i++){ - if (i != 1){ + if (input_length > 1) { + for (int i = 0; i < input_length; i++) { + if (i != 1) { input_size_tmp[i] = 1; } } @@ -38,7 +38,7 @@ Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, } // backward - if (grad == 1){ + if (grad == 1) { OpCommand cmd; cmd.Name("FusedBiasLeakyReluGrad") .Input(input) From 
15a3c764a3762c87fb4671f291be809c7dad8557 Mon Sep 17 00:00:00 2001 From: jayggh <35617559+jayggh@users.noreply.github.com> Date: Mon, 12 Dec 2022 19:50:43 +0800 Subject: [PATCH 67/67] Update fused_bias_leakyrelu_npu.cpp --- mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp index cc4f4b37c0..cd052b5868 100644 --- a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp @@ -9,7 +9,7 @@ Tensor fused_bias_leakyrelu_op_impl(const Tensor &input, const Tensor &bias, Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias, const Tensor &refer, int act, int grad, - float alpha, float scale){ + float alpha, float scale) { at::Tensor py = at::empty_like(input); // forward if (grad == 0) {
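
For reference, below is a minimal Python-level sketch of how the NPU-adapted ops touched by this series (nms and fused_bias_leakyrelu) are exercised, mirroring the tensor shapes used in the updated unit tests. It assumes an mmcv build with `MMCV_WITH_OPS=1 FORCE_NPU=1` and an installed `torch_npu`; the box coordinates and threshold values are illustrative only and are not taken from the patches.

```python
# Illustrative sketch only: assumes torch_npu and an NPU-enabled mmcv build.
# The ops below dispatch to the NPU kernels added in this series when their
# input tensors live on an 'npu' device.
import torch
import torch_npu  # noqa: F401  (registers the 'npu' device with torch)

from mmcv.ops import FusedBiasLeakyReLU, nms

# nms_npu: boxes/scores placed on NPU; coordinates are made up for the example
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [20., 20., 30., 30.]]).npu()
scores = torch.tensor([0.9, 0.8, 0.7]).npu()
dets, keep = nms(boxes, scores, iou_threshold=0.5)

# fused_bias_leakyrelu_npu: same shapes as the updated unit test
x = torch.randn(2, 2, 2, 2, requires_grad=True).npu()
act = FusedBiasLeakyReLU(2).npu()
out = act(x)
out.sum().backward()
```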