[Feature] Add fast_conv_bn_eval option in ConvModule for fast validation and training in Eval mode #2807
Conversation
Another related implementation is FrozenBatchNorm2d from torchvision and detectron2. The implementation in this PR is faster than FrozenBatchNorm2d, with almost the same memory cost (significantly less than the current ConvModule in mmcv); see Table 8 of the paper "Tune-Mode ConvBN Blocks For Efficient Transfer Learning". Besides, this PR does not hurt performance, while FrozenBatchNorm2d does: Table 6 of the MMDetection report shows FrozenBatchNorm2d is worse in mAP, whereas this PR is equivalent to the norm_eval setting. Figure 1 of the same paper shows that norm_eval is prevalent in MMDetection. Therefore, I think this PR can be a drop-in improvement for mmcv: it automatically identifies the cases where acceleration is possible and applies an equivalent implementation.
The implementation is compatible with ONNX export.
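For example, a minimal export sanity check might look like the following sketch. It assumes the option added by this PR is exposed as a `fast_conv_bn_eval` constructor argument of ConvModule, per the PR title; the exact argument name in the merged code may differ.

```python
# Hedged sketch: export a ConvModule using the fast conv-bn eval path to ONNX.
# The `fast_conv_bn_eval` keyword is assumed from the PR title; check the merged API.
import torch
from mmcv.cnn import ConvModule

conv_bn = ConvModule(3, 16, 3, padding=1,
                     norm_cfg=dict(type='BN'),
                     fast_conv_bn_eval=True)
conv_bn.eval()  # the fast path only applies when BN is in eval mode

dummy = torch.randn(1, 3, 32, 32)
torch.onnx.export(conv_bn, dummy, 'conv_bn.onnx', opset_version=11)
```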
Here is an example usage:

```python
# Import required libraries
from typing import Tuple
from functools import partial
from operator import attrgetter
import torch
import torch.nn as nn
import torch.fx as fx
from mmcv.cnn import ConvModule
# Helper function to split a qualname into parent path and last atom.
def _parent_name(target : str) -> Tuple[str, str]:
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
def replace_sub_module(model, name, new_module):
    # Replace the sub-module at the given dotted name with a new module
    # usage: replace_sub_module(model, 'layer1.block2.conv2', conv)
    parent_name, name = _parent_name(name)
    if parent_name != '':
        getter = attrgetter(parent_name)
        parent = getter(model)
    else:
        parent = model
    setattr(parent, name, new_module)
# Main function to merge consecutive conv+bn into ConvModule for the given model
def find_and_merge_conv_bn(model: torch.nn.Module):
    # Symbolically trace the input model to create an FX GraphModule
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    patterns = [(torch.nn.modules.conv._ConvNd, torch.nn.modules.batchnorm._BatchNorm)]
    # Iterate through nodes in the graph to find ConvBN blocks
    for node in fx_model.graph.nodes:
        if node.op != 'call_module':  # If our current node isn't calling a Module then we can ignore it.
            continue
        # A match requires a BN module whose input is produced by a conv module
        found_pair = [node for conv_class, bn_class in patterns
                      if isinstance(modules[node.target], bn_class)
                      and node.args[0].op == 'call_module'
                      and isinstance(modules[node.args[0].target], conv_class)]
        if not found_pair or len(node.args[0].users) > 1:  # Not a conv-BN pattern or output of conv is used by other nodes
            continue
        # Find a pair of conv and bn to optimize
        conv_name = node.args[0].target
        bn_name = node.target
        print(f'Merging {conv_name} and {bn_name} into a ConvModule')
        conv = modules[conv_name]
        bn = modules[bn_name]
        # Fuse conv and bn into a ConvModule
        new_conv = ConvModule.create_from_conv_bn(conv, bn)
        replace_sub_module(model, conv_name, new_conv)
        replace_sub_module(model, bn_name, nn.Identity())
if __name__ == '__main__':
    import torchvision.models as models
    from copy import deepcopy
    resnet = models.resnet50(pretrained=False)
    resnet.eval()
    resnet2 = deepcopy(resnet)
    resnet2.eval()
    find_and_merge_conv_bn(resnet2)
    resnet.cuda()
    resnet2.cuda()
    input = torch.randn(32, 3, 224, 224).cuda()
    output = resnet(input)
    output2 = resnet2(input)
    print(torch.allclose(output, output2, atol=1e-4))
    del output
    del output2
    import time
    start = time.time()
    # reset pytorch max_memory_allocated
    torch.cuda.reset_max_memory_allocated()
    start_memory = torch.cuda.memory_allocated()
    for i in range(10):
        resnet(input).sum().backward()
    end = time.time()
    max_memory = torch.cuda.max_memory_allocated()
    print(f'time for resnet: {end - start} seconds (10 batches with batch size 32)')
    print(f'max memory for resnet: {(max_memory - start_memory) / 1024 ** 3} GB')
    start = time.time()
    # reset pytorch max_memory_allocated
    torch.cuda.reset_max_memory_allocated()
    start_memory = torch.cuda.memory_allocated()
    for i in range(10):
        resnet2(input).sum().backward()
    end = time.time()
    max_memory = torch.cuda.max_memory_allocated()
    print(f'time for resnet with ConvModule: {end - start} seconds (10 batches with batch size 32)')
    print(f'max memory for resnet with ConvModule: {(max_memory - start_memory) / 1024 ** 3} GB')
```

On my server with an RTX 2080 Ti GPU, the output is:
```
Merging conv and bn into a ConvModule
...
```

Update: I re-ran the test, with the following results:
The memory reduction is obvious, but the time reduction is less pronounced; the wall-clock time varies from run to run and is not very stable.
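For reference, GPU wall-clock numbers are usually more stable with a warmup phase and explicit synchronization before reading the clock; a minimal sketch (the `benchmark` helper below is illustrative and not part of this PR):

```python
import time
import torch

def benchmark(model, inp, iters=10, warmup=3):
    # Illustrative helper (not part of this PR): warm up first, then synchronize
    # around the timed region so queued CUDA kernels are fully accounted for.
    for _ in range(warmup):
        model(inp).sum().backward()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        model(inp).sum().backward()
    torch.cuda.synchronize()
    return (time.time() - start) / iters
```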
Thanks for your guidance, and I've tested it. The training accuracy matches the original result well:

```
DONE (t=10.26s).
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.365
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.555
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.389
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.205
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.400
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.481
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.538
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.538
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.538
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.333
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.582
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.691
06/12 13:47:05 - mmengine - INFO - bbox_mAP_copypaste: 0.365 0.555 0.389 0.205 0.400 0.481
06/12 13:47:06 - mmengine - INFO - Epoch(val) [12][625/625] coco/bbox_mAP: 0.3650 coco/bbox_mAP_50: 0.5550 coco/bbox_mAP_75: 0.3890 coco/bbox_mAP_s: 0.2050 coco/bbox_mAP_m: 0.4000 coco/bbox_mAP_l: 0.4810 data_time: 0.0019 time: 0.0268
```

Besides, the memory optimization is also obvious:

Result of fast conv-bn:
Result of normal conv-bn:

```
2023/06/12 11:14:35 - mmengine - INFO - Epoch(train) [1][ 50/7330] lr: 9.9098e-04 eta: 6:11:53 time: 0.2538 data_time: 0.0058 memory: 3306 loss: 1.9298 loss_cls: 1.2243 loss_bbox: 0.7054
2023/06/12 11:14:41 - mmengine - INFO - Epoch(train) [1][ 100/7330] lr: 1.9920e-03 eta: 4:33:11 time: 0.1193 data_time: 0.0033 memory: 3303 loss: 1.8993 loss_cls: 1.2241 loss_bbox: 0.6752
2023/06/12 11:14:47 - mmengine - INFO - Epoch(train) [1][ 150/7330] lr: 2.9930e-03 eta: 3:53:54 time: 0.1064 data_time: 0.0033 memory: 3301 loss: 1.9179 loss_cls: 1.2271 loss_bbox: 0.6908
2023/06/12 11:14:52 - mmengine - INFO - Epoch(train) [1][ 200/7330] lr: 3.9940e-03 eta: 3:35:22 time: 0.1095 data_time: 0.0032 memory: 3306 loss: 1.9083 loss_cls: 1.2412 loss_bbox: 0.6671
2023/06/12 11:14:57 - mmengine - INFO - Epoch(train) [1][ 250/7330] lr: 4.9950e-03 eta: 3:20:48 time: 0.0978 data_time: 0.0033 memory: 3305 loss: 1.7696 loss_cls: 1.1125 loss_bbox: 0.6571
2023/06/12 11:15:02 - mmengine - INFO - Epoch(train) [1][ 300/7330] lr: 5.9960e-03 eta: 3:11:37 time: 0.1001 data_time: 0.0033 memory: 3302 loss: 1.6912 loss_cls: 1.0489 loss_bbox: 0.6424
2023/06/12 11:15:07 - mmengine - INFO - Epoch(train) [1][ 350/7330] lr: 6.9970e-03 eta: 3:04:33 time: 0.0978 data_time: 0.0033 memory: 3307 loss: 1.6070 loss_cls: 0.9797 loss_bbox: 0.6273
2023/06/12 11:15:12 - mmengine - INFO - Epoch(train) [1][ 400/7330] lr: 7.9980e-03 eta: 2:58:13 time: 0.0922 data_time: 0.0033 memory: 3303 loss: 1.7249 loss_cls: 1.1164 loss_bbox: 0.6085
2023/06/12 11:15:16 - mmengine - INFO - Epoch(train) [1][ 450/7330] lr: 8.9990e-03 eta: 2:53:47 time: 0.0954 data_time: 0.0033 memory: 3302 loss: 1.5828 loss_cls: 0.9846 loss_bbox: 0.5982
2023/06/12 11:15:21 - mmengine - INFO - Epoch(train) [1][ 500/7330] lr: 1.0000e-02 eta: 2:50:09 time: 0.0949 data_time: 0.0033 memory: 3304 loss: 1.4992 loss_cls: 0.9300 loss_bbox: 0.5692
2023/06/12 11:15:26 - mmengine - INFO - Epoch(train) [1][ 550/7330] lr: 1.0000e-02 eta: 2:47:14 time: 0.0954 data_time: 0.0032 memory: 3302 loss: 1.4713 loss_cls: 0.9419 loss_bbox: 0.5294
```

The memory allocated is reduced from about 3300 MB to 2432 MB.
Motivation
This PR is motivated by the arXiv paper "Tune-Mode ConvBN Blocks For Efficient Transfer Learning" (https://arxiv.org/abs/2305.11624). It leverages the associativity of convolution and the affine transform applied by an eval-mode BN, i.e., normalize(weight conv feature) = (normalize weight) conv feature.
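Concretely, this is the standard conv-bn folding identity for a ConvBN block in Eval mode (here $W$ and $b$ are the conv weight and bias, $\mu$ and $\sigma^2$ the BN running statistics, $\gamma$ and $\beta$ the BN affine parameters, $\epsilon$ the BN eps, all broadcast per output channel):

$$
\mathrm{BN}(\mathrm{Conv}(x)) = \gamma \cdot \frac{W \ast x + b - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
= \Big(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W\Big) \ast x + \Big(\frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta\Big),
$$

so the whole block is equivalent to a single convolution with a rescaled weight and a shifted bias.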
It has two advantages: it reduces memory consumption and speeds up computation, while remaining numerically equivalent to the original conv + BN computation.
Modification
The implementation appears as a pre-forward hook registered on the conv layer, so it is compatible with the existing implementation. During each forward pass it checks whether the hook should be activated and, if so, switches to the fast computation.
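Conceptually, when the hook is active the block computes a folded weight and bias from the eval-mode BN statistics and then runs a single convolution. A minimal sketch of that computation (an illustration of the idea, not the PR's actual hook code; it assumes an affine BatchNorm2d):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv_bn_eval_forward(conv: nn.Conv2d, bn: nn.BatchNorm2d, x: torch.Tensor):
    # Fold eval-mode BN statistics into the conv parameters, then convolve once.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # per-channel scale
    weight = conv.weight * scale.reshape(-1, 1, 1, 1)         # rescaled conv weight
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    bias = (bias - bn.running_mean) * scale + bn.bias         # shifted bias
    return F.conv2d(x, weight, bias, conv.stride, conv.padding,
                    conv.dilation, conv.groups)
```

Because the affine transform is applied to the weights rather than to the feature map, the intermediate convolution output does not have to be stored for the backward pass, which is roughly where the memory saving comes from.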
BC-breaking (Optional)
This should not break any existing code.
Use cases (Optional)
There are two possible use cases:
1. Define a post_build_model hook in MMCV, used by default. The hook traces the network (typically only the backbone) to replace consecutive conv and bn pairs with the new ConvModule, so downstream users seamlessly enjoy the speedup (see the sketch after this list).
2. Modify the build_model function of each downstream repo (such as MMDetection and MMPose) to trace consecutive conv and bn pairs and replace them with the new ConvModule.
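Both use cases come down to the trace-and-replace step shown in the example above. A hedged sketch of such an integration point (the post_build_model name and the backbone attribute are illustrative assumptions, not existing MMCV APIs):

```python
# Hypothetical integration sketch, reusing the find_and_merge_conv_bn helper
# defined in the example above; names here are illustrative, not MMCV APIs.
def post_build_model(model):
    # Typically only the backbone is traced, since that is where most
    # consecutive conv + bn pairs live.
    find_and_merge_conv_bn(model.backbone)
    return model
```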
Checklist
Before PR:
After PR: