
[Feature] support to calculate FLOPs of GN, IN, LN #897

Merged 3 commits into open-mmlab:master on Mar 19, 2021

Conversation

@zhouzaida (Collaborator) commented on Mar 18, 2021

Related issue: #886

Support:

  1. Add support for calculating the FLOPs of GroupNorm, InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, and LayerNorm
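Normalization layers touch every element of their input (mean/variance reduction plus an optional affine transform), so their FLOPs can be estimated per element. Below is a minimal sketch of what such a forward hook could look like; `norm_flops_counter_hook` and the 2-FLOPs-per-element estimate are illustrative assumptions, not necessarily what mmcv ships.

```python
import numpy as np
import torch
import torch.nn as nn


def norm_flops_counter_hook(module, input, output):
    # Norm layers normalize every input element, so a small constant
    # number of FLOPs per element is a reasonable estimate.
    input = input[0]
    batch_flops = np.prod(input.shape)
    # An affine transform (weight * x + bias) roughly doubles the
    # per-element cost; GN/IN expose `affine`, LN `elementwise_affine`.
    if getattr(module, 'affine', False) or getattr(module, 'elementwise_affine', False):
        batch_flops *= 2
    module.__flops__ = int(batch_flops)


ln = nn.LayerNorm(8)
ln.register_forward_hook(norm_flops_counter_hook)
ln(torch.randn(4, 8))
print(ln.__flops__)  # 4 * 8 elements * 2 = 64
```

The same hook can serve GroupNorm, the InstanceNorm variants, and LayerNorm, since only the element count and the affine flag differ between them.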

Discussion:
1. How to support torch.bmm?
Currently, FLOPs are only calculated for modules that inherit from nn.Module. Operations that are not nn.Module subclasses, such as torch.bmm and torch.nn.functional.conv2d, are therefore not counted.

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn.utils import get_model_complexity_info


class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 256, 3)
        self.conv3 = nn.Conv2d(256, 8, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        # nn.Module
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        # torch.nn.functional.conv1d is not supported
        filters = torch.randn(33, 16, 3)
        inputs = torch.randn(20, 16, 50)
        outputs = F.conv1d(inputs, filters)
        # torch.bmm is not supported
        inputs_1 = torch.randn(10, 3, 4)
        inputs_2 = torch.randn(10, 4, 5)
        outputs = torch.bmm(inputs_1, inputs_2)
        return outputs


get_model_complexity_info(Dummy(), (3, 16, 16))
"""
Dummy(
  0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, 
  (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
  (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
  (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
  (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
)
"""

Perhaps we could use a decorator to support torch.bmm, torch.nn.functional.conv2d, and similar functional operations. For example:

from collections import defaultdict
import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.utils import get_model_complexity_info


def bmm_flops_count(input, output):
    args, kwargs = input
    input1, input2 = args[:2]
    # MACs per batch element: n * k * m (the batch dimension
    # input1.shape[0] is not included here)
    return np.prod(input1.shape[1:]) * input2.shape[-1]


method_mapping = {
    'bmm': bmm_flops_count,
}
flops_cnt = defaultdict(int)


def flops_count_wrapper(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        output = func(*args, **kwargs)
        name = func.__name__
        flops_cnt[name] += method_mapping[name](input=(args, kwargs),
                                                output=output)
        return output
    return wrapper


# decorate wrapper
torch.bmm = flops_count_wrapper(torch.bmm)


class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 256, 3)
        self.conv3 = nn.Conv2d(256, 8, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        # nn.Module
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        # torch.nn.functional.conv1d
        filters = torch.randn(33, 16, 3)
        inputs = torch.randn(20, 16, 50)
        outputs = F.conv1d(inputs, filters)
        # torch.bmm
        inputs_1 = torch.randn(10, 3, 4)
        inputs_2 = torch.randn(10, 4, 5)
        outputs = torch.bmm(inputs_1, inputs_2)
        return outputs


get_model_complexity_info(Dummy(), (3, 16, 16))
print(flops_cnt)
"""
Dummy(
  0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, 
  (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
  (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
  (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
  (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
)
defaultdict(<class 'int'>, {'bmm': 60})
"""

2. Should deconv_flops_counter_hook and conv_flops_counter_hook be the same?
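One way to see the relationship (an illustration with hypothetical helper names, not mmcv's actual hooks): a regular conv performs in_channels/groups * kh * kw MACs per *output* element, while a transposed conv scatters each *input* element through the kernel into out_channels/groups * kh * kw outputs. The two formulas are structurally identical; only the spatial size they multiply by differs.

```python
import torch
import torch.nn as nn


def conv_macs(conv, out_h, out_w):
    # Standard conv: every output element is a dot product over
    # (in_channels / groups) * kh * kw inputs.
    kh, kw = conv.kernel_size
    return conv.out_channels * out_h * out_w * (conv.in_channels // conv.groups) * kh * kw


def deconv_macs(deconv, in_h, in_w):
    # Transposed conv: every input element is scattered through the
    # kernel into (out_channels / groups) * kh * kw outputs.
    kh, kw = deconv.kernel_size
    return deconv.in_channels * in_h * in_w * (deconv.out_channels // deconv.groups) * kh * kw


conv = nn.Conv2d(16, 32, 3)
deconv = nn.ConvTranspose2d(16, 32, 3)
x = torch.randn(1, 16, 8, 8)
out_h, out_w = conv(x).shape[-2:]       # 6 x 6 for a 3x3 kernel at stride 1
print(conv_macs(conv, out_h, out_w))    # 32 * 6 * 6 * 16 * 3 * 3 = 165888
print(deconv_macs(deconv, 8, 8))        # 16 * 8 * 8 * 32 * 3 * 3 = 294912
```

So a single hook could serve both, provided it picks the output spatial size for conv and the input spatial size for transposed conv.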

@MeowZheng (Collaborator) left a comment


LGTM

@codecov bot commented on Mar 18, 2021

Codecov Report

Merging #897 (cf0b651) into master (73bff4e) will increase coverage by 0.01%.
The diff coverage is 76.47%.

❗ Current head cf0b651 differs from pull request most recent head e9a8f3c. Consider uploading reports for the commit e9a8f3c to get more accurate results.

@@            Coverage Diff             @@
##           master     #897      +/-   ##
==========================================
+ Coverage   66.58%   66.59%   +0.01%     
==========================================
  Files         145      145              
  Lines        8828     8841      +13     
  Branches     1605     1606       +1     
==========================================
+ Hits         5878     5888      +10     
- Misses       2633     2637       +4     
+ Partials      317      316       -1     
Flag      | Coverage Δ
unittests | 66.59% <76.47%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                  | Coverage Δ
mmcv/runner/base_module.py      | 79.41% <71.42%> (-6.31%) ⬇️
mmcv/cnn/utils/flops_counter.py | 93.63% <100.00%> (+0.45%) ⬆️
mmcv/runner/__init__.py         | 100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 73bff4e...e9a8f3c.

@hellock hellock merged commit 97730c2 into open-mmlab:master Mar 19, 2021
@zhouzaida zhouzaida deleted the flop_cnt branch March 19, 2021 06:15