Skip to content

Commit

Permalink
[Enhance] Add AMP support for MLU_DCNv2 (open-mmlab#2548)
Browse files Browse the repository at this point in the history
  • Loading branch information
mengpenghui authored and root committed Jan 30, 2023
1 parent e691c56 commit 1360bf4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
5 changes: 4 additions & 1 deletion mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,13 @@ def forward(self, x):
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
x = x.type_as(offset)
weight = self.weight.type_as(x)
mask = mask.type_as(x)
return tv_deform_conv2d(
x,
offset,
self.weight,
weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
Expand Down
19 changes: 12 additions & 7 deletions tests/test_ops/test_modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _test_mdconv(self, dtype=torch.float, device='cuda'):
assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
dcn_offset_b_grad, 1e-2)

def _test_amp_mdconv(self, input_dtype=torch.float):
def _test_amp_mdconv(self, input_dtype=torch.float, device='cuda'):
"""The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half,
Expand All @@ -84,10 +84,15 @@ def _test_amp_mdconv(self, input_dtype=torch.float):
Args:
input_dtype: torch.float or torch.half.
"""
if not torch.cuda.is_available():
if not torch.cuda.is_available() and device == 'cuda':
return
from mmcv.ops import ModulatedDeformConv2dPack
input = torch.tensor(input_t).cuda().type(input_dtype)
if device == 'mlu':
from mmcv.ops import \
ModulatedDeformConv2dPack_MLU as ModulatedDeformConv2dPack
else:
from mmcv.ops import ModulatedDeformConv2dPack

input = torch.tensor(input_t).to(device).type(input_dtype)
input.requires_grad = True

dcn = ModulatedDeformConv2dPack(
Expand All @@ -97,7 +102,7 @@ def _test_amp_mdconv(self, input_dtype=torch.float):
stride=1,
padding=1,
deform_groups=1,
bias=False).cuda()
bias=False).to(device)
dcn.weight.data.fill_(1.)
output = dcn(input)
output.sum().backward()
Expand Down Expand Up @@ -126,5 +131,5 @@ def test_mdconv(self):
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half)
self._test_amp_mdconv(torch.float, device=device)
self._test_amp_mdconv(torch.half, device=device)

0 comments on commit 1360bf4

Please sign in to comment.