[Feature] Add fast_conv_bn_eval option in ConvModule for fast validation and training in Eval mode #2807

Merged · 8 commits · Jun 13, 2023
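In brief: when the BatchNorm following a conv in a ConvModule is in eval mode, the two layers can be folded into a single convolution on the fly, saving memory and computation during validation and during training with frozen bn statistics. A minimal usage sketch of the new option (mirroring the test added below; shapes are illustrative):

import torch
from mmcv.cnn import ConvModule

conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
conv.norm.eval()  # bn in eval mode enables the fused forward path
x = torch.rand(1, 3, 256, 256)
out = conv(x)  # a single fused convolution instead of conv followed by bn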
134 changes: 132 additions & 2 deletions mmcv/cnn/bricks/conv_module.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from functools import partial
from typing import Dict, Optional, Tuple, Union

import torch
@@ -14,6 +15,57 @@
from .padding import build_padding_layer


def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
x: torch.Tensor):
"""
Implementation based on https://arxiv.org/abs/2305.11624
"Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
It leverages the associative law between convolution and affine transform,
i.e., normalize (weight conv feature) = (normalize weight) conv feature.
It works for Eval mode of ConvBN blocks during validation, and can be used
for training as well. It reduces memory and computation cost.

Args:
bn (_BatchNorm): a BatchNorm module.
conv (nn._ConvNd): a conv module
x (torch.Tensor): Input feature map.
"""
# These lines of code are designed to deal with various cases
# like bn without affine transform, and conv without bias
weight_on_the_fly = conv.weight
if conv.bias is not None:
bias_on_the_fly = conv.bias
else:
bias_on_the_fly = torch.zeros_like(bn.running_var)

if bn.weight is not None:
bn_weight = bn.weight
else:
bn_weight = torch.ones_like(bn.running_var)

if bn.bias is not None:
bn_bias = bn.bias
else:
bn_bias = torch.zeros_like(bn.running_var)

weight_coeff = torch.rsqrt(bn.running_var +
bn.eps)  # shape of [C_out] in Conv2d
weight_coeff = weight_coeff.reshape(
[-1] + [1] * (len(conv.weight.shape) - 1))  # [C_out, 1, 1, 1] in Conv2d
coeff_on_the_fly = bn_weight.view_as(
weight_coeff) * weight_coeff  # shape of [C_out, 1, 1, 1]

# shape of [C_out, C_in, k, k] in Conv2d
weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
bias_on_the_fly = (
bias_on_the_fly - bn.running_mean
) * coeff_on_the_fly.flatten() + bn_bias  # shape of [C_out]

return conv.__class__._conv_forward(conv, x, weight_on_the_fly,
bias_on_the_fly)


@MODELS.register_module()
class ConvModule(nn.Module):
"""A conv block that bundles conv/norm/activation layers.
@@ -65,6 +117,9 @@ class ConvModule(nn.Module):
sequence of "conv", "norm" and "act". Common examples are
("conv", "norm", "act") and ("act", "conv", "norm").
Default: ('conv', 'norm', 'act').
fast_conv_bn_eval (bool): Whether to use the fast fused conv when the
following bn is in eval mode (during either training or testing),
as proposed in https://arxiv.org/abs/2305.11624. Default: False.
"""

_abbr_ = 'conv_block'
@@ -84,7 +139,8 @@ def __init__(self,
inplace: bool = True,
with_spectral_norm: bool = False,
padding_mode: str = 'zeros',
order: tuple = ('conv', 'norm', 'act')):
order: tuple = ('conv', 'norm', 'act'),
fast_conv_bn_eval: bool = False):
super().__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
@@ -155,6 +211,16 @@ def __init__(self,
else:
self.norm_name = None # type: ignore

# fast_conv_bn_eval works for conv + bn
# with `track_running_stats` option
if fast_conv_bn_eval and self.norm and isinstance(
self.norm, _BatchNorm) and self.norm.track_running_stats:
self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
self.norm, self.conv)
else:
self.fast_conv_bn_eval_forward = None # type: ignore
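# keep a reference to the original conv forward so that forward() can
# switch between the plain and the fused implementations per call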
self.original_conv_forward = self.conv.forward

# build activation layer
if self.with_activation:
act_cfg_ = act_cfg.copy() # type: ignore
@@ -200,13 +266,77 @@ def forward(self,
x: torch.Tensor,
activate: bool = True,
norm: bool = True) -> torch.Tensor:
for layer in self.order:
layer_index = 0
while layer_index < len(self.order):
layer = self.order[layer_index]
if layer == 'conv':
if self.with_explicit_padding:
x = self.padding_layer(x)
# if the next operation is a norm layer in eval mode and
# fast_conv_bn_eval is enabled for this conv, switch to the
# fused forward and skip the norm layer, since it has already
# been folded into the conv
if layer_index + 1 < len(self.order) and \
self.order[layer_index + 1] == 'norm' and norm and \
self.with_norm and not self.norm.training and \
self.fast_conv_bn_eval_forward is not None:
self.conv.forward = self.fast_conv_bn_eval_forward
layer_index += 1
else:
self.conv.forward = self.original_conv_forward
x = self.conv(x)
elif layer == 'norm' and norm and self.with_norm:
x = self.norm(x)
elif layer == 'act' and activate and self.with_activation:
x = self.activate(x)
layer_index += 1
return x

@staticmethod
def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
bn: torch.nn.modules.batchnorm._BatchNorm,
fast_conv_bn_eval=True) -> 'ConvModule':
"""Create a ConvModule from a conv and a bn module."""
self = ConvModule.__new__(ConvModule)
super(ConvModule, self).__init__()

self.conv_cfg = None
self.norm_cfg = None
self.act_cfg = None
self.inplace = False
self.with_spectral_norm = False
self.with_explicit_padding = False
self.order = ('conv', 'norm', 'act')

self.with_norm = True
self.with_activation = False
self.with_bias = conv.bias is not None

# build convolution layer
self.conv = conv
# export the attributes of self.conv to a higher level for convenience
self.in_channels = self.conv.in_channels
self.out_channels = self.conv.out_channels
self.kernel_size = self.conv.kernel_size
self.stride = self.conv.stride
self.padding = self.conv.padding
self.dilation = self.conv.dilation
self.transposed = self.conv.transposed
self.output_padding = self.conv.output_padding
self.groups = self.conv.groups

# build normalization layers
self.norm_name, norm = 'bn', bn
self.add_module(self.norm_name, norm)

# fast_conv_bn_eval works for conv + bn
# with `track_running_stats` option
if fast_conv_bn_eval and self.norm and isinstance(
self.norm, _BatchNorm) and self.norm.track_running_stats:
self.fast_conv_bn_eval_forward = partial(fast_conv_bn_eval_forward,
self.norm, self.conv)
else:
self.fast_conv_bn_eval_forward = None # type: ignore
self.original_conv_forward = self.conv.forward

return self
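As a usage note, a minimal sketch of wrapping an existing conv/bn pair with create_from_conv_bn (module sizes here are illustrative):

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

conv = nn.Conv2d(3, 8, 2)
bn = nn.BatchNorm2d(8)
module = ConvModule.create_from_conv_bn(conv, bn, fast_conv_bn_eval=True)
module.norm.eval()  # the fused path is taken only when bn is in eval mode
out = module(torch.rand(1, 3, 256, 256))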
10 changes: 10 additions & 0 deletions tests/test_cnn/test_conv_module.py
@@ -75,6 +75,16 @@ def test_conv_module():
output = conv(x)
assert output.shape == (1, 8, 255, 255)

# conv + norm with fast mode
conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
conv.norm.eval()
x = torch.rand(1, 3, 256, 256)
fast_mode_output = conv(x)
conv.conv.forward = conv.original_conv_forward
plain_implementation = conv.activate(conv.norm(conv.conv(x)))
assert torch.allclose(fast_mode_output, plain_implementation, atol=1e-5)
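
# an additional check one might add (a sketch, not part of this PR):
# gradients should still flow through the fused path, since the fast
# mode also targets training with bn kept in eval mode
conv = ConvModule(
3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True)
conv.norm.eval()
out = conv(torch.rand(1, 3, 256, 256))
out.sum().backward()
assert conv.conv.weight.grad is not None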

# conv + act
conv = ConvModule(3, 8, 2)
assert conv.with_activation