From 4bd3b5027aee4e3bfde9e429a59f0eab3d341c3e Mon Sep 17 00:00:00 2001
From: Guangchen Lin <347630870@qq.com>
Date: Sun, 23 May 2021 17:42:59 +0800
Subject: [PATCH] [Fix] Support amp (pytorch >= 1.6.0) on DCN and DCNv2 / Add
 unit tests on DCN/DCNv2 amp (#1029)

* fix fp16 bug on DCNv2

* support fp16 on DCN/DCNv2 when pytorch >= '1.6.0'

* add comment

* Modified the comments

* Unified the usages of '.to()' and '.type_as()'
---
 mmcv/ops/deform_conv.py                      | 10 ++-
 mmcv/ops/modulated_deform_conv.py            |  9 +++
 tests/test_ops/test_deform_conv.py           | 70 ++++++++++++++++++++
 tests/test_ops/test_modulated_deform_conv.py | 54 +++++++++++++++
 4 files changed, 141 insertions(+), 2 deletions(-)

diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index 5282e26193..04666f58db 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -70,8 +70,14 @@ def forward(ctx,
         ctx.deform_groups = deform_groups
         ctx.im2col_step = im2col_step
 
-        # until the code is modified for torch.cuda.amp.autocast,
-        # we need to cast weight to avoid type mismatch in fp16 training
+        # When the pytorch version is >= 1.6.0, amp is used for fp16 mode;
+        # amp does not cast the model (it stays float32), but "offset" is
+        # cast to float16 by nn.Conv2d automatically, which causes a type
+        # mismatch with the input (when it is float32) or the weight.
+        # The dtype of "offset" therefore indicates whether fp16 or amp is
+        # in use; as a temporary workaround, we cast the input and weight
+        # to that dtype so both fp16 and amp work on any pytorch version.
+        input = input.type_as(offset)
         weight = weight.type_as(input)
 
         ctx.save_for_backward(input, offset, weight)
diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index b8ff1adeb2..b3dfd0b003 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -57,6 +57,15 @@ def forward(ctx,
         ctx.with_bias = bias is not None
         if not ctx.with_bias:
             bias = input.new_empty(0)  # fake tensor
+        # When the pytorch version is >= 1.6.0, amp is used for fp16 mode;
+        # amp does not cast the model (it stays float32), but "offset" is
+        # cast to float16 by nn.Conv2d automatically, which causes a type
+        # mismatch with the input (when it is float32) or the weight.
+        # The dtype of "offset" therefore indicates whether fp16 or amp is
+        # in use; as a temporary workaround, we cast the input and weight
+        # to that dtype so both fp16 and amp work on any pytorch version.
+        input = input.type_as(offset)
+        weight = weight.type_as(input)
         ctx.save_for_backward(input, offset, mask, weight, bias)
         output = input.new_empty(
             ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py
index b99df8d011..c49d47980a 100644
--- a/tests/test_ops/test_deform_conv.py
+++ b/tests/test_ops/test_deform_conv.py
@@ -2,6 +2,15 @@
 import pytest
 import torch
 
+from mmcv.utils import TORCH_VERSION
+
+try:
+    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+    # is imported and used; we should test whether our modules support it.
+    from torch.cuda.amp import autocast
+except ImportError:
+    pass
+
 input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
 offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
                  [[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
@@ -71,7 +80,68 @@ def _test_deformconv(self, dtype=torch.float, threshold=1e-3):
         with pytest.raises(AssertionError):
             model = DeformConv2d(3, 4, 3, groups=3)
 
+    def _test_amp_deformconv(self, input_dtype, threshold=1e-3):
+        """Test deform_conv under amp, which is released in pytorch 1.6.0.
+
+        The type of the input data might be torch.float or torch.half,
+        so we should test deform_conv in both cases. With amp, the data
+        type of the model is NOT set manually.
+
+        Args:
+            input_dtype: torch.float or torch.half.
+            threshold: the same as in the function above.
+        """
+        if not torch.cuda.is_available():
+            return
+        from mmcv.ops import DeformConv2dPack
+        c_in = 1
+        c_out = 1
+        x = torch.Tensor(input).cuda().type(input_dtype)
+        x.requires_grad = True
+        model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
+        model.conv_offset.weight.data = torch.nn.Parameter(
+            torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
+        model.conv_offset.bias.data = torch.nn.Parameter(
+            torch.Tensor(offset_bias).reshape(8))
+        model.weight.data = torch.nn.Parameter(
+            torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
+        model.cuda()
+
+        out = model(x)
+        out.backward(torch.ones_like(out))
+
+        assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold)
+        assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold)
+        assert np.allclose(
+            model.conv_offset.weight.grad.detach().cpu().numpy(),
+            gt_offset_weight_grad, threshold)
+        assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(),
+                           gt_offset_bias_grad, threshold)
+        assert np.allclose(model.weight.grad.detach().cpu().numpy(),
+                           gt_deform_weight_grad, threshold)
+
+        from mmcv.ops import DeformConv2d
+        # test bias
+        model = DeformConv2d(1, 1, 2, stride=1, padding=0)
+        assert not hasattr(model, 'bias')
+        # test bias=True
+        with pytest.raises(AssertionError):
+            model = DeformConv2d(1, 1, 2, stride=1, padding=0, bias=True)
+        # test in_channels % group != 0
+        with pytest.raises(AssertionError):
+            model = DeformConv2d(3, 2, 3, groups=2)
+        # test out_channels % group != 0
+        with pytest.raises(AssertionError):
+            model = DeformConv2d(3, 4, 3, groups=3)
+
     def test_deformconv(self):
         self._test_deformconv(torch.double)
         self._test_deformconv(torch.float)
         self._test_deformconv(torch.half, 1e-1)
+
+        # test amp when the torch version is >= '1.6.0'; the type of the
+        # input data for deformconv might be torch.float or torch.half
+        if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+            with autocast(enabled=True):
+                self._test_amp_deformconv(torch.float, 1e-1)
+                self._test_amp_deformconv(torch.half, 1e-1)
diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py
index 43ddd66707..83c6f8a405 100644
--- a/tests/test_ops/test_modulated_deform_conv.py
+++ b/tests/test_ops/test_modulated_deform_conv.py
@@ -3,6 +3,15 @@
 import numpy
 import torch
 
+from mmcv.utils import TORCH_VERSION
+
+try:
+    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+    # is imported and used; we should test whether our modules support it.
+    from torch.cuda.amp import autocast
+except ImportError:
+    pass
+
 cur_dir = os.path.dirname(os.path.abspath(__file__))
 
 input_t = [[[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]]]
@@ -58,7 +67,52 @@ def _test_mdconv(self, dtype=torch.float):
         assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
                               dcn_offset_b_grad, 1e-2)
 
+    def _test_amp_mdconv(self, input_dtype=torch.float):
+        """Test mdconv under amp, which is released in pytorch 1.6.0.
+
+        The type of the input data might be torch.float or torch.half,
+        so we should test mdconv in both cases. With amp, the data
+        type of the model is NOT set manually.
+
+        Args:
+            input_dtype: torch.float or torch.half.
+        """
+        if not torch.cuda.is_available():
+            return
+        from mmcv.ops import ModulatedDeformConv2dPack
+        input = torch.tensor(input_t).cuda().type(input_dtype)
+        input.requires_grad = True
+
+        dcn = ModulatedDeformConv2dPack(
+            1,
+            1,
+            kernel_size=(2, 2),
+            stride=1,
+            padding=1,
+            deform_groups=1,
+            bias=False).cuda()
+        dcn.weight.data.fill_(1.)
+        output = dcn(input)
+        output.sum().backward()
+        assert numpy.allclose(output.cpu().detach().numpy(), output_t, 1e-2)
+        assert numpy.allclose(input.grad.cpu().detach().numpy(), input_grad,
+                              1e-2)
+        assert numpy.allclose(dcn.weight.grad.cpu().detach().numpy(),
+                              dcn_w_grad, 1e-2)
+        assert numpy.allclose(
+            dcn.conv_offset.weight.grad.cpu().detach().numpy(),
+            dcn_offset_w_grad, 1e-2)
+        assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
+                              dcn_offset_b_grad, 1e-2)
+
     def test_mdconv(self):
         self._test_mdconv(torch.double)
         self._test_mdconv(torch.float)
         self._test_mdconv(torch.half)
+
+        # test amp when the torch version is >= '1.6.0'; the type of the
+        # input data for mdconv might be torch.float or torch.half
+        if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+            with autocast(enabled=True):
+                self._test_amp_mdconv(torch.float)
+                self._test_amp_mdconv(torch.half)
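
For reference, a minimal sketch of the behaviour this patch enables (illustrative only, not part of the patch): under torch.cuda.amp.autocast the model parameters stay float32 while the internal nn.Conv2d emits a float16 "offset", and the type_as() casts added above reconcile the dtypes. The sketch assumes a CUDA device, PyTorch >= 1.6.0 and an mmcv build that contains this fix; the model configuration and variable names are arbitrary.

    # Illustrative usage sketch (not part of the patch): run DCN under amp.
    # Assumes a CUDA device, PyTorch >= 1.6.0 and mmcv with this fix installed.
    import torch
    from mmcv.ops import DeformConv2dPack

    model = DeformConv2dPack(3, 8, kernel_size=3, padding=1).cuda()  # stays fp32
    x = torch.randn(2, 3, 16, 16, device='cuda')

    with torch.cuda.amp.autocast(enabled=True):
        # The nn.Conv2d inside DeformConv2dPack produces a float16 offset here;
        # the added type_as() casts align input/weight with it, so no dtype error.
        out = model(x)

    out.float().sum().backward()

Casting the input and weight to the dtype of the offset, rather than disabling autocast around the op, keeps the deformable convolution kernels running in the precision amp selected for the preceding convolution.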