diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md
index 6a0241412b..6215716824 100644
--- a/docs/en/understand_mmcv/ops.md
+++ b/docs/en/understand_mmcv/ops.md
@@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
 | ConvexIoU | | √ | | | |
 | CornerPool | | √ | | | |
 | Correlation | | √ | | | |
-| Deformable Convolution v1/v2 | √ | √ | | | √ |
+| Deformable Convolution v1/v2 | √ | √ | √ | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
 | DiffIoURotated | | √ | | | |
 | DynamicScatter | | √ | | | |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index deeb60eede..e8f96a8be0 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | ConvexIoU | | √ | | | |
 | CornerPool | | √ | | | |
 | Correlation | | √ | | | |
-| Deformable Convolution v1/v2 | √ | √ | | | √ |
+| Deformable Convolution v1/v2 | √ | √ | √ | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
 | DiffIoURotated | | √ | | | |
 | DynamicScatter | | √ | | | |
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index bcb9a5a4da..bdad553736 100755
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -109,6 +109,7 @@
 ]
 
 if IS_MLU_AVAILABLE:
+    from .deform_conv import DeformConv2dPack_MLU  # noqa:F401
     from .modulated_deform_conv import \
         ModulatedDeformConv2dPack_MLU  # noqa:F401
-    __all__.append('ModulatedDeformConv2dPack_MLU')
+    __all__.extend(['ModulatedDeformConv2dPack_MLU', 'DeformConv2dPack_MLU'])
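Since `DeformConv2dPack_MLU` re-registers itself under the `'DCN'` key with `force=True` (see the `deform_conv.py` diff below), downstream configs that build deformable convolutions through the registry need no changes. A minimal sketch of that resolution; the channel and kernel values here are illustrative, not part of this change:

```python
import torch
from mmcv.cnn import build_conv_layer
from mmcv.utils import IS_MLU_AVAILABLE

# conv_cfg 'DCN' resolves to DeformConv2dPack_MLU when mmcv is imported
# on an MLU host (force=True overrides the default registration), and to
# the usual DeformConv2dPack elsewhere; callers keep their configs as-is.
dcn = build_conv_layer(
    dict(type='DCN', deform_groups=1),  # illustrative values
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    padding=1)

device = 'mlu' if IS_MLU_AVAILABLE else 'cpu'
x = torch.randn(2, 3, 16, 16).to(device)
out = dcn.to(device)(x)  # shape: (2, 8, 16, 16)
```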
diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index dcc0abb6e7..50112d31fc 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -9,7 +9,7 @@
 from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single
 
-from mmcv.utils import deprecated_api_warning
+from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning
 from ..cnn import CONV_LAYERS
 from ..utils import ext_loader, print_log
 from .modulated_deform_conv import ModulatedDeformConv2dFunction
@@ -434,3 +434,67 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, missing_keys, unexpected_keys,
                                       error_msgs)
+
+
+if IS_MLU_AVAILABLE:
+    import torchvision
+
+    from mmcv.utils import digit_version
+    assert digit_version(torchvision.__version__) >= digit_version(
+        '0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
+
+    from torchvision.ops import deform_conv2d as tv_deform_conv2d
+
+    @CONV_LAYERS.register_module('DCN', force=True)
+    class DeformConv2dPack_MLU(DeformConv2d):
+        """This class is the DCN implementation on MLU devices. The MLU
+        backend of this operator is implemented in torchvision, so this
+        class reuses the mmcv registration mechanism and dispatches to
+        the torchvision implementation of DCN.
+
+        Args:
+            in_channels (int): Same as nn.Conv2d.
+            out_channels (int): Same as nn.Conv2d.
+            kernel_size (int or tuple[int]): Same as nn.Conv2d.
+            stride (int): Same as nn.Conv2d, while tuple is not supported.
+            padding (int): Same as nn.Conv2d, while tuple is not supported.
+            dilation (int): Same as nn.Conv2d, while tuple is not supported.
+            groups (int): Same as nn.Conv2d.
+            bias (bool or str): If specified as `auto`, it will be decided by
+                the norm_cfg. Bias will be set as True if norm_cfg is None,
+                otherwise False.
+            im2col_step (int): Number of samples processed by
+                im2col_cuda_kernel per call. It will work when ``batch_size``
+                > ``im2col_step``, but ``batch_size`` must be divisible by
+                ``im2col_step``. Default: 32. `New in version 1.7.2.
+                Currently not supported on MLU devices.`
+        """
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+            self.conv_offset = nn.Conv2d(
+                self.in_channels,
+                self.deform_groups * 2 * self.kernel_size[0] *
+                self.kernel_size[1],
+                kernel_size=self.kernel_size,
+                stride=_pair(self.stride),
+                padding=_pair(self.padding),
+                dilation=_pair(self.dilation),
+                bias=True)
+            self.init_offset()
+
+        def init_offset(self):
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+        def forward(self, x: Tensor) -> Tensor:  # type: ignore
+            cur_im2col_step = min(self.im2col_step, x.size(0))
+            assert (x.size(0) % cur_im2col_step
+                    ) == 0, 'batch size must be divisible by im2col_step'
+            offset = self.conv_offset(x)
+            x = x.type_as(offset)
+            weight = self.weight
+            weight = weight.type_as(x)
+            return tv_deform_conv2d(x, offset, weight, None, self.stride,
+                                    self.padding, self.dilation)
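The new class keeps the offset-predicting `nn.Conv2d` from `DeformConv2dPack` and delegates the deformable convolution itself to `torchvision.ops.deform_conv2d`. The sketch below spells out the same wiring functionally; the values are illustrative, and CPU is used only so it runs anywhere torchvision does. The key shape constraint is that the offset tensor carries `2 * deform_groups * kh * kw` channels and matches the output's spatial size:

```python
import torch
import torch.nn as nn
from torchvision.ops import deform_conv2d

# Same wiring as DeformConv2dPack_MLU, spelled out functionally. On an
# MLU host the tensors would live on the 'mlu' device instead of CPU.
in_ch, out_ch, kh, kw, deform_groups = 3, 8, 3, 3, 1
x = torch.randn(2, in_ch, 16, 16)

# Offset predictor: 2 coordinates (dy, dx) per kernel location and group.
conv_offset = nn.Conv2d(
    in_ch, deform_groups * 2 * kh * kw, kernel_size=(kh, kw), padding=1)
nn.init.zeros_(conv_offset.weight)  # zero offsets == plain convolution
nn.init.zeros_(conv_offset.bias)

weight = torch.randn(out_ch, in_ch, kh, kw)
offset = conv_offset(x)  # (2, 18, 16, 16), matches the output H x W
out = deform_conv2d(x, offset, weight, bias=None, stride=1, padding=1)
print(out.shape)  # torch.Size([2, 8, 16, 16])
```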
""" - if not torch.cuda.is_available(): + if not torch.cuda.is_available() and device == 'cuda': return - from mmcv.ops import DeformConv2dPack + if device == 'mlu': + from mmcv.ops import DeformConv2dPack_MLU as DeformConv2dPack + else: + from mmcv.ops import DeformConv2dPack c_in = 1 c_out = 1 repeated_input = np.repeat(input, batch_size, axis=0) repeated_gt_out = np.repeat(gt_out, batch_size, axis=0) repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0) - x = torch.Tensor(repeated_input).cuda().type(input_dtype) + x = torch.Tensor(repeated_input).to(device).type(input_dtype) x.requires_grad = True model = DeformConv2dPack( in_channels=c_in, @@ -143,7 +152,10 @@ def _test_amp_deformconv(self, torch.Tensor(offset_bias).reshape(8)) model.weight.data = torch.nn.Parameter( torch.Tensor(deform_weight).reshape(1, 1, 2, 2)) - model.cuda() + if device == 'cuda': + model.cuda() + elif device == 'mlu': + model.mlu() out = model(x) out.backward(torch.ones_like(out)) @@ -180,21 +192,25 @@ def _test_amp_deformconv(self, def test_deformconv(self): self._test_deformconv(torch.double, device='cpu') self._test_deformconv(torch.float, device='cpu', threshold=1e-1) - self._test_deformconv(torch.double) - self._test_deformconv(torch.float) - self._test_deformconv(torch.half, threshold=1e-1) + + device = 'mlu' if IS_MLU_AVAILABLE else 'cuda' + self._test_deformconv(torch.double, device=device) + self._test_deformconv(torch.float, device=device) + self._test_deformconv(torch.half, threshold=1e-1, device=device) # test batch_size < im2col_step - self._test_deformconv(torch.float, batch_size=1, im2col_step=2) + self._test_deformconv( + torch.float, batch_size=1, im2col_step=2, device=device) # test bach_size % im2col_step != 0 with pytest.raises( AssertionError, match='batch size must be divisible by im2col_step'): - self._test_deformconv(torch.float, batch_size=10, im2col_step=3) + self._test_deformconv( + torch.float, batch_size=10, im2col_step=3, device=device) # test amp when torch version >= '1.6.0', the type of # input data for deformconv might be torch.float or torch.half if (TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): with autocast(enabled=True): - self._test_amp_deformconv(torch.float, 1e-1) - self._test_amp_deformconv(torch.half, 1e-1) + self._test_amp_deformconv(torch.float, 1e-1, device) + self._test_amp_deformconv(torch.half, 1e-1, device)