Merge branch 'npu-dev' of https://github.com/jayggh/mmcv into npu-dev
jayggh committed Dec 12, 2022
2 parents 1f88e77 + 7bb7049 commit eb4a194
Showing 5 changed files with 109 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
@@ -24,7 +24,7 @@ We implement common ops used in detection, segmentation, etc.
| DynamicScatter | | √ | | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |
-| FusedBiasLeakyrelu | | √ | | | |
+| FusedBiasLeakyrelu | | √ | | | √ |
| GatherPoints | | √ | | | |
| GroupPoints | | √ | | | |
| Iou3d | | √ | √ | | |
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
@@ -24,7 +24,7 @@ MMCV provides operators commonly used in detection, segmentation, and other tasks
| DynamicScatter | | √ | | | |
| FurthestPointSample | | √ | | | |
| FurthestPointSampleWithDist | | √ | | | |
-| FusedBiasLeakyrelu | | √ | | | |
+| FusedBiasLeakyrelu | | √ | | | √ |
| GatherPoints | | √ | | | |
| GroupPoints | | √ | | | |
| Iou3d | | √ | √ | | |
60 changes: 60 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp
@@ -0,0 +1,60 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

Tensor fused_bias_leakyrelu_op_impl(const Tensor &input, const Tensor &bias,
const Tensor &refer, int act, int grad,
float alpha, float scale);

Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias,
const Tensor &refer, int act, int grad,
float alpha, float scale)
{
at::Tensor py = at::empty_like(input);
// forward
if (grad == 0)
{
auto input_size = input.sizes();
int input_length = input_size.size();
c10::SmallVector<int64_t, SIZE> input_size_tmp;
input_size_tmp = array_to_small_vector(input_size);
if (input_length > 1)
{
for (int i = 0; i < input_length; i++)
{
if (i != 1)
{
input_size_tmp[i] = 1;
}
}
}
at::Tensor bias_tmp = at::reshape(bias, input_size_tmp);
at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(
bias_tmp, input.sizes());
OpCommand cmd;
cmd.Name("FusedBiasLeakyRelu")
.Input(input)
.Input(bias_)
.Output(py)
.Attr("scale", scale)
.Attr("negative_slope", alpha)
.Run();
}

// backward
if (grad == 1)
{
OpCommand cmd;
cmd.Name("FusedBiasLeakyReluGrad")
.Input(input)
.Input(refer)
.Output(py)
.Attr("scale", scale)
.Attr("negative_slope", alpha)
.Run();
}
return py;
}

REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, fused_bias_leakyrelu_npu);
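
For orientation, the fused kernel computes the same result as mmcv's pure-PyTorch fallback `bias_leakyrelu_ref`: add a per-channel bias, apply leaky ReLU, then rescale. A minimal sketch of those semantics (the function name below is illustrative, not part of this diff):

```python
import torch
import torch.nn.functional as F


def bias_leakyrelu_sketch(x: torch.Tensor, bias: torch.Tensor,
                          negative_slope: float = 0.2,
                          scale: float = 2**0.5) -> torch.Tensor:
    # Broadcast the per-channel bias to (1, C, 1, ..., 1), mirroring the
    # reshape/broadcast performed in fused_bias_leakyrelu_npu above.
    shape = [1, -1] + [1] * (x.dim() - 2)
    return F.leaky_relu(x + bias.reshape(shape), negative_slope) * scale
```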
2 changes: 1 addition & 1 deletion mmcv/ops/fused_bias_leakyrelu.py
@@ -258,7 +258,7 @@ def fused_bias_leakyrelu(input: torch.Tensor,
torch.Tensor: Feature map after non-linear activation.
"""

-    if not input.is_cuda:
+    if not input.is_cuda and input.device.type != 'npu':
return bias_leakyrelu_ref(input, bias, negative_slope, scale)

return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
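With this dispatch change, NPU tensors now reach the fused kernel while other non-CUDA tensors still fall back to the reference implementation. A hedged usage sketch (assumes a machine with torch_npu installed; the shapes are arbitrary):

```python
import torch

from mmcv.ops import fused_bias_leakyrelu

x = torch.randn(2, 8, 4, 4).npu()    # NPU input -> fused_bias_leakyrelu_npu
bias = torch.zeros(8).npu()          # one bias value per channel
out = fused_bias_leakyrelu(x, bias)  # defaults: negative_slope=0.2, scale=2**0.5

x_cpu = torch.randn(2, 8, 4, 4)      # CPU input -> bias_leakyrelu_ref fallback
out_cpu = fused_bias_leakyrelu(x_cpu, torch.zeros(8))
```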
67 changes: 46 additions & 21 deletions tests/test_ops/test_fused_bias_leakyrelu.py
@@ -2,48 +2,73 @@
import pytest
import torch

+from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
+
_USING_PARROTS = True
try:
    from parrots.autograd import gradcheck
except ImportError:
-    from torch.autograd import gradcheck, gradgradcheck
+    from torch.autograd import gradcheck
    _USING_PARROTS = False


class TestFusedBiasLeakyReLU:

    @classmethod
    def setup_class(cls):
-        if not torch.cuda.is_available():
+        if not IS_CUDA_AVAILABLE and not IS_NPU_AVAILABLE:
            return
-        cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda()
-        cls.bias = torch.zeros(2, requires_grad=True).cuda()
+        if IS_CUDA_AVAILABLE:
+            cls.input_tensor = torch.randn((2, 2, 2, 2),
+                                           requires_grad=True).cuda()
+            cls.bias = torch.zeros(2, requires_grad=True).cuda()
+        elif IS_NPU_AVAILABLE:
+            cls.input_tensor = torch.randn((2, 2, 2, 2),
+                                           requires_grad=True).npu()
+            cls.bias = torch.zeros(2, requires_grad=True).npu()

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
-    def test_gradient(self):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+    ])
+    def test_gradient(self, device):

        from mmcv.ops import FusedBiasLeakyReLU
        if _USING_PARROTS:
-            gradcheck(
-                FusedBiasLeakyReLU(2).cuda(),
-                self.input_tensor,
-                delta=1e-4,
-                pt_atol=1e-3)
+            if IS_CUDA_AVAILABLE:
+                gradcheck(
+                    FusedBiasLeakyReLU(2).cuda(),
+                    self.input_tensor,
+                    delta=1e-4,
+                    pt_atol=1e-3)
        else:
            gradcheck(
-                FusedBiasLeakyReLU(2).cuda(),
+                FusedBiasLeakyReLU(2).to(device),
                self.input_tensor,
                eps=1e-4,
                atol=1e-3)

-    @pytest.mark.skipif(
-        not torch.cuda.is_available() or _USING_PARROTS,
-        reason='requires cuda')
-    def test_gradgradient(self):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+    ])
+    def test_gradgradient(self, device):

        from mmcv.ops import FusedBiasLeakyReLU
-        gradgradcheck(
-            FusedBiasLeakyReLU(2).cuda(),
-            self.input_tensor,
-            eps=1e-4,
-            atol=1e-3)
+        gradcheck(
+            FusedBiasLeakyReLU(2).to(device),
+            self.input_tensor,
+            eps=1e-4,
+            atol=1e-3)
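
The `IS_CUDA_AVAILABLE` and `IS_NPU_AVAILABLE` flags imported at the top of the test come from `mmcv.utils`. A rough sketch of what such availability flags typically check (mmcv's actual helpers may differ in detail):

```python
import torch


def _npu_available() -> bool:
    # torch_npu patches torch with an `npu` namespace on Ascend hosts.
    try:
        import torch_npu  # noqa: F401
        return torch.npu.is_available()
    except (ImportError, AttributeError):
        return False


IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = _npu_available()
```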
