From b028c351e6d03ada026d7887c82355a366e4c3a1 Mon Sep 17 00:00:00 2001
From: AronLin <347630870@qq.com>
Date: Mon, 17 May 2021 15:30:33 +0800
Subject: [PATCH 1/5] fix fp16 bug on DCNv2

---
 mmcv/ops/modulated_deform_conv.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index b8ff1adeb2..2a9b60d624 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -57,6 +57,9 @@ def forward(ctx,
         ctx.with_bias = bias is not None
         if not ctx.with_bias:
             bias = input.new_empty(0)  # fake tensor
+        # until the code is modified for torch.cuda.amp.autocast,
+        # we need to cast weight to avoid type mismatch in fp16 training
+        weight = weight.type_as(input)
         ctx.save_for_backward(input, offset, mask, weight, bias)
         output = input.new_empty(
             ModulatedDeformConv2dFunction._output_size(ctx, input, weight))

From d9dc8f56df947a88c17e0492e7ebb3a67c50fa42 Mon Sep 17 00:00:00 2001
From: aronlin <347630870@qq.com>
Date: Wed, 19 May 2021 17:38:23 +0800
Subject: [PATCH 2/5] support fp16 on DCN/DCNv2 when pytorch >= '1.6.0'

---
 mmcv/ops/deform_conv.py                      |  7 +-
 mmcv/ops/modulated_deform_conv.py            |  7 +-
 tests/test_ops/test_deform_conv.py           | 70 ++++++++++++++++++++
 tests/test_ops/test_modulated_deform_conv.py | 54 +++++++++++++++
 4 files changed, 134 insertions(+), 4 deletions(-)

diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index 5282e26193..3c8de84c87 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -70,8 +70,11 @@ def forward(ctx,
         ctx.deform_groups = deform_groups
         ctx.im2col_step = im2col_step
 
-        # until the code is modified for torch.cuda.amp.autocast,
-        # we need to cast weight to avoid type mismatch in fp16 training
+        # The flag for whether to use fp16 (pytorch < 1.6.0) or
+        # map (pytorch >= 1.6.0) is the type of "offset", we
+        # cast weight and input to temporarily support fp16 and
+        # amp whatever the pytorch version is.
+        input = input.to(offset.dtype)
         weight = weight.type_as(input)
         ctx.save_for_backward(input, offset, weight)
 
diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index 2a9b60d624..fe9816afb7 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -57,8 +57,11 @@ def forward(ctx,
         ctx.with_bias = bias is not None
         if not ctx.with_bias:
             bias = input.new_empty(0)  # fake tensor
-        # until the code is modified for torch.cuda.amp.autocast,
-        # we need to cast weight to avoid type mismatch in fp16 training
+        # The flag for whether to use fp16 (pytorch < 1.6.0) or
+        # map (pytorch >= 1.6.0) is the type of "offset", we
+        # cast weight and input to temporarily support fp16 and
+        # amp whatever the pytorch version is.
+        input = input.to(offset.dtype)
         weight = weight.type_as(input)
         ctx.save_for_backward(input, offset, mask, weight, bias)
         output = input.new_empty(
diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py
index b99df8d011..c49d47980a 100644
--- a/tests/test_ops/test_deform_conv.py
+++ b/tests/test_ops/test_deform_conv.py
@@ -2,6 +2,15 @@
 import pytest
 import torch
 
+from mmcv.utils import TORCH_VERSION
+
+try:
+    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+    # would be imported and used; we should test if our modules support it.
+    from torch.cuda.amp import autocast
+except ImportError:
+    pass
+
 input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
 offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
                  [[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
@@ -71,7 +80,68 @@ def _test_deformconv(self, dtype=torch.float, threshold=1e-3):
         with pytest.raises(AssertionError):
             model = DeformConv2d(3, 4, 3, groups=3)
 
+    def _test_amp_deformconv(self, input_dtype, threshold=1e-3):
+        """The function to test amp released on pytorch 1.6.0.
+
+        The type of input data might be torch.float or torch.half,
+        so we should test deform_conv in both cases. With amp, the
+        data type of model will NOT be set manually.
+
+        Args:
+            input_dtype: torch.float or torch.half.
+            threshold: the same as above function.
+        """
+        if not torch.cuda.is_available():
+            return
+        from mmcv.ops import DeformConv2dPack
+        c_in = 1
+        c_out = 1
+        x = torch.Tensor(input).cuda().type(input_dtype)
+        x.requires_grad = True
+        model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
+        model.conv_offset.weight.data = torch.nn.Parameter(
+            torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
+        model.conv_offset.bias.data = torch.nn.Parameter(
+            torch.Tensor(offset_bias).reshape(8))
+        model.weight.data = torch.nn.Parameter(
+            torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
+        model.cuda()
+
+        out = model(x)
+        out.backward(torch.ones_like(out))
+
+        assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold)
+        assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold)
+        assert np.allclose(
+            model.conv_offset.weight.grad.detach().cpu().numpy(),
+            gt_offset_weight_grad, threshold)
+        assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(),
+                           gt_offset_bias_grad, threshold)
+        assert np.allclose(model.weight.grad.detach().cpu().numpy(),
+                           gt_deform_weight_grad, threshold)
+
+        from mmcv.ops import DeformConv2d
+        # test bias
+        model = DeformConv2d(1, 1, 2, stride=1, padding=0)
+        assert not hasattr(model, 'bias')
+        # test bias=True
+        with pytest.raises(AssertionError):
+            model = DeformConv2d(1, 1, 2, stride=1, padding=0, bias=True)
+        # test in_channels % group != 0
+        with pytest.raises(AssertionError):
+            model = DeformConv2d(3, 2, 3, groups=2)
+        # test out_channels % group != 0
+        with pytest.raises(AssertionError):
+            model = DeformConv2d(3, 4, 3, groups=3)
+
     def test_deformconv(self):
         self._test_deformconv(torch.double)
         self._test_deformconv(torch.float)
         self._test_deformconv(torch.half, 1e-1)
+
+        # test amp when torch version >= '1.6.0', the type of
+        # input data for deformconv might be torch.float or torch.half
+        if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+            with autocast(enabled=True):
+                self._test_amp_deformconv(torch.float, 1e-1)
+                self._test_amp_deformconv(torch.half, 1e-1)
diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py
index 43ddd66707..83c6f8a405 100644
--- a/tests/test_ops/test_modulated_deform_conv.py
+++ b/tests/test_ops/test_modulated_deform_conv.py
@@ -3,6 +3,15 @@
 import numpy
 import torch
 
+from mmcv.utils import TORCH_VERSION
+
+try:
+    # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
+    # would be imported and used; we should test if our modules support it.
+    from torch.cuda.amp import autocast
+except ImportError:
+    pass
+
 cur_dir = os.path.dirname(os.path.abspath(__file__))
 
 input_t = [[[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]]]
@@ -58,7 +67,52 @@ def _test_mdconv(self, dtype=torch.float):
         assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
                               dcn_offset_b_grad, 1e-2)
 
+    def _test_amp_mdconv(self, input_dtype=torch.float):
+        """The function to test amp released on pytorch 1.6.0.
+
+        The type of input data might be torch.float or torch.half,
+        so we should test mdconv in both cases. With amp, the data
+        type of model will NOT be set manually.
+
+        Args:
+            input_dtype: torch.float or torch.half.
+        """
+        if not torch.cuda.is_available():
+            return
+        from mmcv.ops import ModulatedDeformConv2dPack
+        input = torch.tensor(input_t).cuda().type(input_dtype)
+        input.requires_grad = True
+
+        dcn = ModulatedDeformConv2dPack(
+            1,
+            1,
+            kernel_size=(2, 2),
+            stride=1,
+            padding=1,
+            deform_groups=1,
+            bias=False).cuda()
+        dcn.weight.data.fill_(1.)
+        output = dcn(input)
+        output.sum().backward()
+        assert numpy.allclose(output.cpu().detach().numpy(), output_t, 1e-2)
+        assert numpy.allclose(input.grad.cpu().detach().numpy(), input_grad,
+                              1e-2)
+        assert numpy.allclose(dcn.weight.grad.cpu().detach().numpy(),
+                              dcn_w_grad, 1e-2)
+        assert numpy.allclose(
+            dcn.conv_offset.weight.grad.cpu().detach().numpy(),
+            dcn_offset_w_grad, 1e-2)
+        assert numpy.allclose(dcn.conv_offset.bias.grad.cpu().detach().numpy(),
+                              dcn_offset_b_grad, 1e-2)
+
     def test_mdconv(self):
         self._test_mdconv(torch.double)
         self._test_mdconv(torch.float)
         self._test_mdconv(torch.half)
+
+        # test amp when torch version >= '1.6.0', the type of
+        # input data for mdconv might be torch.float or torch.half
+        if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+            with autocast(enabled=True):
+                self._test_amp_mdconv(torch.float)
+                self._test_amp_mdconv(torch.half)

From bb0d9eaab0122f1875b0fe65227261034be5fd24 Mon Sep 17 00:00:00 2001
From: AronLin <347630870@qq.com>
Date: Thu, 20 May 2021 11:49:47 +0800
Subject: [PATCH 3/5] add comment

---
 mmcv/ops/deform_conv.py           | 2 +-
 mmcv/ops/modulated_deform_conv.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index 3c8de84c87..923c9d10c5 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -71,7 +71,7 @@ def forward(ctx,
         ctx.im2col_step = im2col_step
 
         # The flag for whether to use fp16 (pytorch < 1.6.0) or
-        # map (pytorch >= 1.6.0) is the type of "offset", we
+        # amp (pytorch >= 1.6.0) is the type of "offset", we
         # cast weight and input to temporarily support fp16 and
         # amp whatever the pytorch version is.
         input = input.to(offset.dtype)
diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index fe9816afb7..7268272a5b 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -58,7 +58,7 @@ def forward(ctx,
         if not ctx.with_bias:
             bias = input.new_empty(0)  # fake tensor
         # The flag for whether to use fp16 (pytorch < 1.6.0) or
-        # map (pytorch >= 1.6.0) is the type of "offset", we
+        # amp (pytorch >= 1.6.0) is the type of "offset", we
         # cast weight and input to temporarily support fp16 and
         # amp whatever the pytorch version is.
         input = input.to(offset.dtype)

From 97ead4dd985e7bc41192488dd37a55b8d5b5c9b6 Mon Sep 17 00:00:00 2001
From: AronLin <347630870@qq.com>
Date: Fri, 21 May 2021 11:48:33 +0800
Subject: [PATCH 4/5] Modified the comments

---
 mmcv/ops/deform_conv.py           | 11 +++++++----
 mmcv/ops/modulated_deform_conv.py | 11 +++++++----
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index 923c9d10c5..a808774245 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -70,10 +70,13 @@ def forward(ctx,
         ctx.deform_groups = deform_groups
         ctx.im2col_step = im2col_step
 
-        # The flag for whether to use fp16 (pytorch < 1.6.0) or
-        # amp (pytorch >= 1.6.0) is the type of "offset", we
-        # cast weight and input to temporarily support fp16 and
-        # amp whatever the pytorch version is.
+        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+        # amp won't cast the type of model (float32), but "offset" is cast
+        # to float16 by nn.Conv2d automatically, leading to the type
+        # mismatch with input (when it is float32) or weight.
+        # The flag for whether to use fp16 or amp is the type of "offset",
+        # we cast weight and input to temporarily support fp16 and amp
+        # whatever the pytorch version is.
         input = input.to(offset.dtype)
         weight = weight.type_as(input)
         ctx.save_for_backward(input, offset, weight)
diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index 7268272a5b..acbe69bcb6 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -57,10 +57,13 @@ def forward(ctx,
         ctx.with_bias = bias is not None
         if not ctx.with_bias:
             bias = input.new_empty(0)  # fake tensor
-        # The flag for whether to use fp16 (pytorch < 1.6.0) or
-        # amp (pytorch >= 1.6.0) is the type of "offset", we
-        # cast weight and input to temporarily support fp16 and
-        # amp whatever the pytorch version is.
+        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+        # amp won't cast the type of model (float32), but "offset" is cast
+        # to float16 by nn.Conv2d automatically, leading to the type
+        # mismatch with input (when it is float32) or weight.
+        # The flag for whether to use fp16 or amp is the type of "offset",
+        # we cast weight and input to temporarily support fp16 and amp
+        # whatever the pytorch version is.
         input = input.to(offset.dtype)
         weight = weight.type_as(input)
         ctx.save_for_backward(input, offset, mask, weight, bias)

From 6c719c0d4d367ba47c98b83149ace34be59c53b1 Mon Sep 17 00:00:00 2001
From: AronLin <347630870@qq.com>
Date: Sun, 23 May 2021 16:01:05 +0800
Subject: [PATCH 5/5] Unified the usages of '.to()' and '.type_as()'

---
 mmcv/ops/deform_conv.py           | 2 +-
 mmcv/ops/modulated_deform_conv.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py
index a808774245..04666f58db 100644
--- a/mmcv/ops/deform_conv.py
+++ b/mmcv/ops/deform_conv.py
@@ -77,7 +77,7 @@ def forward(ctx,
         # The flag for whether to use fp16 or amp is the type of "offset",
         # we cast weight and input to temporarily support fp16 and amp
         # whatever the pytorch version is.
-        input = input.to(offset.dtype)
+        input = input.type_as(offset)
         weight = weight.type_as(input)
         ctx.save_for_backward(input, offset, weight)
 
diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py
index acbe69bcb6..b3dfd0b003 100644
--- a/mmcv/ops/modulated_deform_conv.py
+++ b/mmcv/ops/modulated_deform_conv.py
@@ -64,7 +64,7 @@ def forward(ctx,
         # The flag for whether to use fp16 or amp is the type of "offset",
         # we cast weight and input to temporarily support fp16 and amp
         # whatever the pytorch version is.
-        input = input.to(offset.dtype)
+        input = input.type_as(offset)
         weight = weight.type_as(input)
         ctx.save_for_backward(input, offset, mask, weight, bias)
         output = input.new_empty(
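A minimal usage sketch, not part of the patches above: it shows how the patched op is expected to behave under torch.cuda.amp.autocast, assuming a CUDA-enabled build of mmcv that contains these changes. The layer hyper-parameters, tensor shapes, and the final dtype check are illustrative assumptions, not values taken from the patches.

import torch
from torch.cuda.amp import autocast
from mmcv.ops import ModulatedDeformConv2dPack

if torch.cuda.is_available():
    # Model weights stay float32; under autocast the internal nn.Conv2d
    # produces a float16 "offset", and the patched forward() casts input
    # and weight to that dtype instead of raising a type-mismatch error.
    dcn = ModulatedDeformConv2dPack(
        3, 3, kernel_size=3, stride=1, padding=1, deform_groups=1).cuda()
    x = torch.randn(2, 3, 16, 16, device='cuda', requires_grad=True)
    with autocast(enabled=True):
        out = dcn(x)
    out.float().sum().backward()  # gradients flow back through the cast
    print(out.dtype)  # expected to be torch.float16 inside autocast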