Skip to content

Commit

Permalink
[Fix] Check the version of torchvision in __init__ of DCN (open-mmlab…
Browse files Browse the repository at this point in the history
  • Loading branch information
mengpenghui authored and ClowDragon committed Oct 24, 2023
1 parent bc9c3ee commit bd64735
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
10 changes: 4 additions & 6 deletions mmcv/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,

if IS_MLU_AVAILABLE:
import torchvision
from torchvision.ops import deform_conv2d as tv_deform_conv2d

from mmcv.utils import digit_version
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'

from torchvision.ops import deform_conv2d as tv_deform_conv2d

@CONV_LAYERS.register_module('DCN', force=True)
class DeformConv2dPack_MLU(DeformConv2d):
Expand Down Expand Up @@ -443,6 +440,8 @@ class DeformConv2dPack_MLU(DeformConv2d):
"""

def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs)

self.conv_offset = nn.Conv2d(
Expand All @@ -466,7 +465,6 @@ def forward(self, x: Tensor) -> Tensor: # type: ignore
) == 0, 'batch size must be divisible by im2col_step'
offset = self.conv_offset(x)
x = x.type_as(offset)
weight = self.weight
weight = weight.type_as(x)
weight = self.weight.type_as(x)
return tv_deform_conv2d(x, offset, weight, None, self.stride,
self.padding, self.dilation)
6 changes: 3 additions & 3 deletions mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,

if IS_MLU_AVAILABLE:
import torchvision
from torchvision.ops import deform_conv2d as tv_deform_conv2d

from mmcv.utils import digit_version
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
from torchvision.ops import deform_conv2d as tv_deform_conv2d

@CONV_LAYERS.register_module('DCNv2', force=True)
class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
Expand All @@ -383,6 +381,8 @@ class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
"""

def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
Expand Down

0 comments on commit bd64735

Please sign in to comment.