From ac84f549c1425683dd97055ec1e07c231488c3dd Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 22 May 2023 12:28:29 +0800
Subject: [PATCH 1/8] add fast_conv_bn_eval option in ConvModule for fast
 validation and training in Eval mode

---
 mmcv/cnn/bricks/conv_module.py     | 86 +++++++++++++++++++++++++++++-
 tests/test_cnn/test_conv_module.py | 10 ++++
 2 files changed, 94 insertions(+), 2 deletions(-)

diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py
index 1f8e160517..7896347776 100644
--- a/mmcv/cnn/bricks/conv_module.py
+++ b/mmcv/cnn/bricks/conv_module.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from functools import partial
 from typing import Dict, Optional, Tuple, Union
 
 import torch
@@ -14,6 +15,57 @@ from .padding import build_padding_layer
 
 
+def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
+                              x: torch.Tensor):
+    """
+    Implementation based on https://arxiv.org/abs/2305.11624
+    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
+    It leverages the associative law between convolution and affine transform,
+    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
+    It works for Eval mode of ConvBN blocks during validation, and can be used
+    for training as well. It reduces memory and computation cost.
+
+    Args:
+        bn (_BatchNorm): a BatchNorm module.
+        conv (nn._ConvNd): a conv module.
+        x (torch.Tensor): Input feature map.
+    """
+    # These lines of code are designed to deal with various cases
+    # like bn without affine transform, and conv without bias
+    weight_on_the_fly = conv.weight
+    if conv.bias is not None:
+        bias_on_the_fly = conv.bias
+    else:
+        bias_on_the_fly = torch.zeros_like(bn.running_var)
+
+    if bn.weight is not None:
+        bn_weight = bn.weight
+    else:
+        bn_weight = torch.ones_like(bn.running_var)
+
+    if bn.bias is not None:
+        bn_bias = bn.bias
+    else:
+        bn_bias = torch.zeros_like(bn.running_var)
+
+    weight_coeff = torch.rsqrt(bn.running_var +
+                               bn.eps)  # shape of [C_out] in Conv2d
+    weight_coeff = torch.tensor(
+        weight_coeff.reshape([-1] + [1] * (len(conv.weight.shape) - 1))
+    )  # shape of [C_out, 1, 1, 1] in Conv2d
+    coeff_on_the_fly = bn_weight.view_as(
+        weight_coeff) * weight_coeff  # shape of [C_out, 1, 1, 1]
+
+    # shape of [C_out, C_in, k, k] in Conv2d
+    weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
+    bias_on_the_fly = (
+        bias_on_the_fly - bn.running_mean
+    ) * coeff_on_the_fly.flatten() + bn_bias  # shape of [C_out]
+
+    return conv.__class__._conv_forward(conv, x, weight_on_the_fly,
+                                        bias_on_the_fly)
+
+
 @MODELS.register_module()
 class ConvModule(nn.Module):
     """A conv block that bundles conv/norm/activation layers.
@@ -56,6 +108,9 @@ class ConvModule(nn.Module):
             Default: True.
         with_spectral_norm (bool): Whether use spectral norm in conv module.
             Default: False.
+        fast_conv_bn_eval (bool): Whether use fast conv when the consecutive
+            bn is in eval mode (either training or testing), as proposed in
+            https://arxiv.org/abs/2305.11624 . Default: False.
         padding_mode (str): If the `padding_mode` has not been supported by
             current `Conv2d` in PyTorch, we will use our own padding layer
             instead. Currently, we support ['zeros', 'circular'] with official
@@ -83,6 +138,7 @@ def __init__(self,
                  act_cfg: Optional[Dict] = dict(type='ReLU'),
                  inplace: bool = True,
                  with_spectral_norm: bool = False,
+                 fast_conv_bn_eval: bool = False,
                  padding_mode: str = 'zeros',
                  order: tuple = ('conv', 'norm', 'act')):
         super().__init__()
@@ -155,6 +211,17 @@ def __init__(self,
         else:
             self.norm_name = None  # type: ignore
 
+        # fast_conv_bn_eval works for conv + bn
+        # with `track_running_stats` option
+        if fast_conv_bn_eval and self.norm and isinstance(
+                self.norm, _BatchNorm) and self.norm.track_running_stats:
+            self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
+                                                     self.norm, self.conv)
+        else:
+            self.fast_conv_bn_eval_forward = None  # type: ignore
+        self.original_conv_forward = partial(self.conv.__class__.forward,
+                                             self.conv)
+
         # build activation layer
         if self.with_activation:
             act_cfg_ = act_cfg.copy()  # type: ignore
@@ -200,13 +267,28 @@ def forward(self,
                 x: torch.Tensor,
                 activate: bool = True,
                 norm: bool = True) -> torch.Tensor:
-        for layer in self.order:
+        layer_index = 0
+        while layer_index < len(self.order):
+            layer = self.order[layer_index]
             if layer == 'conv':
                 if self.with_explicit_padding:
                     x = self.padding_layer(x)
-                x = self.conv(x)
+                # if the next operation is norm and we have a norm layer in
+                # eval mode and we have enabled fast_conv_bn_eval for the conv
+                # operator, then activate the optimized forward and skip the
+                # next norm operator since it has been fused
+                if self.order[layer_index + 1] == 'norm' and norm and \
+                        self.with_norm and not self.norm.training and \
+                        self.fast_conv_bn_eval_forward is not None:
+                    self.conv.forward = self.fast_conv_bn_eval_forward
+                    x = self.conv(x)
+                    layer_index += 1
+                else:
+                    self.conv.forward = self.original_conv_forward
+                    x = self.conv(x)
             elif layer == 'norm' and norm and self.with_norm:
                 x = self.norm(x)
             elif layer == 'act' and activate and self.with_activation:
                 x = self.activate(x)
+            layer_index += 1
         return x
diff --git a/tests/test_cnn/test_conv_module.py b/tests/test_cnn/test_conv_module.py
index d31167a743..af7fc25ec1 100644
--- a/tests/test_cnn/test_conv_module.py
+++ b/tests/test_cnn/test_conv_module.py
@@ -75,6 +75,16 @@ def test_conv_module():
     output = conv(x)
     assert output.shape == (1, 8, 255, 255)
 
+    # conv + norm with fast mode
+    conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
+    conv.norm.eval()
+    x = torch.rand(1, 3, 256, 256)
+    fast_mode_output = conv(x)
+    conv.conv.forward = conv.original_conv_forward
+    plain_implementation = conv.activate(conv.norm(conv.conv(x)))
+    assert torch.allclose(fast_mode_output, plain_implementation, atol=1e-5)
+
    # conv + act
     conv = ConvModule(3, 8, 2)
     assert conv.with_activation
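Aside on patch 1: the identity it exploits can be checked with plain PyTorch, independent of mmcv. The minimal sketch below (all names are illustrative, not part of the patch) fuses an eval-mode bn into the conv weight and bias on the fly, using bn(conv(x)) = conv'(x) with W' = coeff * W and b' = beta + coeff * (b - running_mean), where coeff = gamma / sqrt(running_var + eps), and compares the result against running conv and bn sequentially. It calls the same private helper `_conv_forward` that the series itself relies on:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, 3, bias=True)
    bn = nn.BatchNorm2d(8)

    # populate non-trivial running statistics, then freeze them
    bn.train()
    with torch.no_grad():
        bn(conv(torch.rand(4, 3, 32, 32)))
    bn.eval()

    coeff = torch.rsqrt(bn.running_var + bn.eps) * bn.weight     # [C_out]
    fused_weight = conv.weight * coeff.reshape(-1, 1, 1, 1)      # [C_out, C_in, k, k]
    fused_bias = bn.bias + coeff * (conv.bias - bn.running_mean)  # [C_out]

    x = torch.rand(1, 3, 32, 32)
    fused_out = conv._conv_forward(x, fused_weight, fused_bias)
    plain_out = bn(conv(x))
    assert torch.allclose(fused_out, plain_out, atol=1e-5)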
From ae287f8157d667623b21bbc8db8d241e9484b1f8 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 25 May 2023 22:25:10 +0800
Subject: [PATCH 2/8] simplify code

---
 mmcv/cnn/bricks/conv_module.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py
index 7896347776..7627b1439f 100644
--- a/mmcv/cnn/bricks/conv_module.py
+++ b/mmcv/cnn/bricks/conv_module.py
@@ -219,8 +219,7 @@ def __init__(self,
                                                      self.norm, self.conv)
         else:
             self.fast_conv_bn_eval_forward = None  # type: ignore
-        self.original_conv_forward = partial(self.conv.__class__.forward,
-                                             self.conv)
+        self.original_conv_forward = self.conv.forward
 
         # build activation layer
         if self.with_activation:
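Patch 2 works because, for a module instance m, the attribute lookup m.forward returns a bound method: the class function with m already bound as self, which is exactly what partial(type(m).forward, m) built by hand. Since the attribute is captured in __init__, before anything assigns an instance-level forward, the two spellings are interchangeable. A minimal sketch:

    from functools import partial

    import torch
    import torch.nn as nn

    m = nn.Conv2d(3, 8, 3)
    x = torch.rand(1, 3, 16, 16)

    via_partial = partial(type(m).forward, m)  # what the old code built
    via_bound = m.forward                      # what the new code captures

    assert torch.allclose(via_partial(x), via_bound(x))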
From b094a34ed81975e8995340e7f3438f2f59010fdd Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 25 May 2023 22:59:04 +0800
Subject: [PATCH 3/8] merge two self.conv(x) calls into one

---
 mmcv/cnn/bricks/conv_module.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py
index 7627b1439f..c175e61d58 100644
--- a/mmcv/cnn/bricks/conv_module.py
+++ b/mmcv/cnn/bricks/conv_module.py
@@ -280,11 +280,10 @@ def forward(self,
                         self.with_norm and not self.norm.training and \
                         self.fast_conv_bn_eval_forward is not None:
                     self.conv.forward = self.fast_conv_bn_eval_forward
-                    x = self.conv(x)
                     layer_index += 1
                 else:
                     self.conv.forward = self.original_conv_forward
-                    x = self.conv(x)
+                x = self.conv(x)
             elif layer == 'norm' and norm and self.with_norm:
                 x = self.norm(x)
             elif layer == 'act' and activate and self.with_activation:
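The single x = self.conv(x) call site left by patch 3 relies on how nn.Module dispatches: module(x) goes through __call__, which looks up self.forward at call time, and an attribute assigned on the instance shadows the class method. A minimal sketch of that mechanism (patched_forward is an illustrative stand-in for fast_conv_bn_eval_forward):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, 3)
    x = torch.rand(1, 3, 16, 16)

    original = conv.forward  # bound method, captured before any patching

    def patched_forward(inp):
        # stand-in for the fused conv-bn forward
        return original(inp)

    conv.forward = patched_forward  # instance attribute shadows the class
    out_patched = conv(x)           # __call__ dispatches to patched_forward

    conv.forward = original         # restore the plain behaviour
    out_plain = conv(x)

    assert torch.allclose(out_patched, out_plain)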
""" _abbr_ = 'conv_block' @@ -138,9 +138,9 @@ def __init__(self, act_cfg: Optional[Dict] = dict(type='ReLU'), inplace: bool = True, with_spectral_norm: bool = False, - fast_conv_bn_eval: bool = False, padding_mode: str = 'zeros', - order: tuple = ('conv', 'norm', 'act')): + order: tuple = ('conv', 'norm', 'act'), + fast_conv_bn_eval: bool = False): super().__init__() assert conv_cfg is None or isinstance(conv_cfg, dict) assert norm_cfg is None or isinstance(norm_cfg, dict) From 391bc1a69ff9c34d04ea1c7cd82d8ed147cc994a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 10 Jun 2023 12:58:59 +0800 Subject: [PATCH 5/8] add a static method to create ConvModule from a pair of existing conv and bn --- mmcv/cnn/bricks/conv_module.py | 49 ++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py index 8029294a19..20f9331fe8 100644 --- a/mmcv/cnn/bricks/conv_module.py +++ b/mmcv/cnn/bricks/conv_module.py @@ -290,3 +290,52 @@ def forward(self, x = self.activate(x) layer_index += 1 return x + + @staticmethod + def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd, + bn: torch.nn.modules.batchnorm._BatchNorm, + fast_conv_bn_eval=True) -> 'ConvModule': + """Create a ConvModule from a conv and a bn module.""" + self = ConvModule.__new__(ConvModule) + super(ConvModule, self).__init__() + + self.conv_cfg = None + self.norm_cfg = None + self.act_cfg = None + self.inplace = False + self.with_spectral_norm = False + self.with_explicit_padding = False + self.order = ('conv', 'norm', 'act') + + self.with_norm = True + self.with_activation = False + self.with_bias = conv.bias is not None + + # build convolution layer + self.conv = conv + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = self.conv.padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + # build normalization layers + self.norm_name, norm = 'bn', bn + self.add_module(self.norm_name, norm) + + # fast_conv_bn_eval works for conv + bn + # with `track_running_stats` option + if fast_conv_bn_eval and self.norm and isinstance( + self.norm, _BatchNorm) and self.norm.track_running_stats: + self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward, + self.norm, self.conv) + else: + self.fast_conv_bn_eval_forward = None # type: ignore + self.original_conv_forward = self.conv.forward + + return self From 51d1e454b26f985efd37557e3ddb973156a4974c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 13 Jun 2023 19:38:25 +0800 Subject: [PATCH 6/8] avoid index out of range --- mmcv/cnn/bricks/conv_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py index 20f9331fe8..9d1869da26 100644 --- a/mmcv/cnn/bricks/conv_module.py +++ b/mmcv/cnn/bricks/conv_module.py @@ -276,7 +276,8 @@ def forward(self, # eval mode and we have enabled fast_conv_bn_eval for the conv # operator, then activate the optimized forward and skip the # next norm operator since it has been fused - if self.order[layer_index + 1] == 'norm' and norm and \ + if layer_index + 1 < len(self.order) and \ + self.order[layer_index + 1] == 'norm' and norm and \ self.with_norm and not self.norm.training and \ 
From 51d1e454b26f985efd37557e3ddb973156a4974c Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 13 Jun 2023 19:38:25 +0800
Subject: [PATCH 6/8] avoid index out of range

---
 mmcv/cnn/bricks/conv_module.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py
index 20f9331fe8..9d1869da26 100644
--- a/mmcv/cnn/bricks/conv_module.py
+++ b/mmcv/cnn/bricks/conv_module.py
@@ -276,7 +276,8 @@ def forward(self,
                 # eval mode and we have enabled fast_conv_bn_eval for the conv
                 # operator, then activate the optimized forward and skip the
                 # next norm operator since it has been fused
-                if self.order[layer_index + 1] == 'norm' and norm and \
+                if layer_index + 1 < len(self.order) and \
+                        self.order[layer_index + 1] == 'norm' and norm and \
                         self.with_norm and not self.norm.training and \
                         self.fast_conv_bn_eval_forward is not None:
                     self.conv.forward = self.fast_conv_bn_eval_forward

From 822d2678a6579662792ffa1dafde8d9cc0781473 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 13 Jun 2023 20:10:44 +0800
Subject: [PATCH 7/8] simplify and beautify the code

---
 mmcv/cnn/bricks/conv_module.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py
index 9d1869da26..9572e7d457 100644
--- a/mmcv/cnn/bricks/conv_module.py
+++ b/mmcv/cnn/bricks/conv_module.py
@@ -48,19 +48,18 @@ def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
     else:
         bn_bias = torch.zeros_like(bn.running_var)
 
+    # shape of [C_out, 1, 1, 1] in Conv2d
     weight_coeff = torch.rsqrt(bn.running_var +
-                               bn.eps)  # shape of [C_out] in Conv2d
-    weight_coeff = torch.tensor(
-        weight_coeff.reshape([-1] + [1] * (len(conv.weight.shape) - 1))
-    )  # shape of [C_out, 1, 1, 1] in Conv2d
-    coeff_on_the_fly = bn_weight.view_as(
-        weight_coeff) * weight_coeff  # shape of [C_out, 1, 1, 1]
+                               bn.eps).reshape([-1] + [1] *
+                                               (len(conv.weight.shape) - 1))
+    # shape of [C_out, 1, 1, 1] in Conv2d
+    coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff
 
     # shape of [C_out, C_in, k, k] in Conv2d
     weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
-    bias_on_the_fly = (
-        bias_on_the_fly - bn.running_mean
-    ) * coeff_on_the_fly.flatten() + bn_bias  # shape of [C_out]
+    # shape of [C_out] in Conv2d
+    bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() *\
+        (bias_on_the_fly - bn.running_mean)
 
     return conv.__class__._conv_forward(conv, x, weight_on_the_fly,
                                         bias_on_the_fly)

From 41784becdef8a017482d5870be64cdb7809c6d42 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 13 Jun 2023 20:17:56 +0800
Subject: [PATCH 8/8] simplify code usage of conv._conv_forward

---
 mmcv/cnn/bricks/conv_module.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py
index 9572e7d457..a8a55ff316 100644
--- a/mmcv/cnn/bricks/conv_module.py
+++ b/mmcv/cnn/bricks/conv_module.py
@@ -61,8 +61,7 @@ def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
     bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() *\
         (bias_on_the_fly - bn.running_mean)
 
-    return conv.__class__._conv_forward(conv, x, weight_on_the_fly,
-                                        bias_on_the_fly)
+    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
 
 
 @MODELS.register_module()
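To close the series, an end-to-end sanity sketch, again assuming an mmcv build with all eight patches applied: the fast path must match the plain path while the bn is in eval mode, and a training-mode bn must fall back to the ordinary conv -> bn -> act forward:

    import torch
    from mmcv.cnn import ConvModule

    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
    x = torch.rand(1, 3, 64, 64)

    conv.norm.eval()  # fused path is active
    fast = conv(x)
    conv.conv.forward = conv.original_conv_forward
    plain = conv.activate(conv.norm(conv.conv(x)))
    assert torch.allclose(fast, plain, atol=1e-5)

    conv.norm.train()  # fused path deactivates automatically
    out = conv(x)      # ordinary conv -> bn -> act; bn stats are updated
    assert out.shape == (1, 8, 63, 63)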