Skip to content

Commit

Permalink
[Fix] Fix a dilation bug of MLU-DCNv2 and add limitation of torchvision (#2519)
Browse files Browse the repository at this point in the history
  • Loading branch information
mengpenghui authored Jan 12, 2023
1 parent 2810718 commit c9d477b
Showing 1 changed file with 13 additions and 35 deletions.
48 changes: 13 additions & 35 deletions mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,


if IS_MLU_AVAILABLE:
import torchvision

from mmcv.utils import digit_version
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
from torchvision.ops import deform_conv2d as tv_deform_conv2d

@CONV_LAYERS.register_module('DCNv2', force=True)
class ModulatedDeformConv2dPack_MLU(nn.modules.Module):
class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
"""This class is the DCNv2 implementation of the MLU device. The MLU
backend support of the operator has been implemented in torchvision.
The mmcv registration mechanism is used for multiplexing here. The
Expand All @@ -377,51 +382,24 @@ class ModulatedDeformConv2dPack_MLU(nn.modules.Module):
otherwise False.
"""

def __init__(self, *args, **kwargs):
    """Initialize the MLU DCNv2 module.

    All positional and keyword arguments are forwarded unchanged to
    ``ModulatedDeformConv2d``, which sets up ``in_channels``,
    ``out_channels``, ``kernel_size``, ``stride``, ``padding``,
    ``dilation``, ``groups``, ``deform_groups``, ``weight`` and
    ``bias`` — so this subclass stays signature-compatible with the
    generic DCNv2 pack while only adding the offset branch.
    """
    super().__init__(*args, **kwargs)
    # conv_offset predicts, per deform group, 2 offset channels (x, y)
    # plus 1 modulation-mask channel for every kernel position — hence
    # the factor 3 * kh * kw.
    self.conv_offset = nn.Conv2d(
        self.in_channels,
        self.deform_groups * 3 * self.kernel_size[0] *
        self.kernel_size[1],
        kernel_size=self.kernel_size,
        stride=self.stride,
        padding=self.padding,
        # Passing dilation here is the dilation bug fix: the offset
        # branch must sample with the same dilation as the main conv,
        # otherwise offsets are predicted on a mismatched grid.
        dilation=self.dilation,
        bias=True)
    self.init_weights()

def init_weights(self):
    """Initialize module weights.

    Delegates the main ``weight``/``bias`` initialization to the
    parent ``ModulatedDeformConv2d.init_weights``, then zeroes the
    offset branch so training starts as a plain (non-deformed,
    unmodulated) convolution.

    The ``hasattr`` guard is needed because the parent's ``__init__``
    also calls ``init_weights`` before ``conv_offset`` exists on this
    subclass.
    """
    super().init_weights()
    if hasattr(self, 'conv_offset'):
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

def forward(self, x):
out = self.conv_offset(x)
Expand Down

0 comments on commit c9d477b

Please sign in to comment.