Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] support to calculate FLOPs of GN, IN, LN #897

Merged
merged 3 commits into from
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 18 additions & 29 deletions mmcv/cnn/utils/flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,11 @@ def get_model_complexity_info(model,
input_constructor=None,
flush=False,
ost=sys.stdout):
"""Get complexity information of a model.
"""Get complexity information of a model. This method can calculate FLOPs
and parameter counts of a model with corresponding input shape. It can also
print complexity information for each layer in a model. Supported layers
are listed as below:

This method can calculate FLOPs and parameter counts of a model with
corresponding input shape. It can also print complexity information for
each layer in a model.

Supported layers are listed as below:
- Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
- Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
``nn.ReLU6``.
Expand All @@ -56,11 +54,11 @@ def get_model_complexity_info(model,
``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
- BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
``nn.BatchNorm3d``.
``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
- Linear: ``nn.Linear``.
- Deconvolution: ``nn.ConvTranspose2d``.
- Upsample: ``nn.Upsample``.

zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
Args:
model (nn.Module): The model for complexity calculation.
input_shape (tuple): Input shape used for calculation.
Expand All @@ -74,7 +72,6 @@ def get_model_complexity_info(model,
flush (bool): same as that in :func:`print`. Default: False.
ost (stream): same as ``file`` param in :func:`print`.
Default: sys.stdout.

Returns:
tuple[float | str]: If ``as_strings`` is set to True, it will return
FLOPs and parameter counts in a string format. otherwise, it will
Expand Down Expand Up @@ -116,19 +113,15 @@ def get_model_complexity_info(model,

def flops_to_string(flops, units='GFLOPs', precision=2):
"""Convert FLOPs number into a string.

Note that Here we take a multiply-add counts as one FLOP.

Args:
flops (float): FLOPs number to be converted.
units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
precision (int): Digit number after the decimal point. Default: 2.

Returns:
str: The converted FLOPs number with units.

Examples:
>>> flops_to_string(1e9)
'1.0 GFLOPs'
Expand Down Expand Up @@ -159,17 +152,14 @@ def flops_to_string(flops, units='GFLOPs', precision=2):

def params_to_string(num_params, units=None, precision=2):
"""Convert parameter number into a string.

Args:
num_params (float): Parameter number to be converted.
units (str | None): Converted FLOPs units. Options are None, 'M',
'K' and ''. If set to None, it will automatically choose the most
suitable unit for Parameter number. Default: None.
precision (int): Digit number after the decimal point. Default: 2.

Returns:
str: The converted parameter number with units.

Examples:
>>> params_to_string(1e9)
'1000.0 M'
Expand Down Expand Up @@ -202,7 +192,6 @@ def print_model_with_flops(model,
ost=sys.stdout,
flush=False):
"""Print a model with FLOPs for each layer.

Args:
model (nn.Module): The model to be printed.
total_flops (float): Total FLOPs of the model.
Expand All @@ -212,10 +201,8 @@ def print_model_with_flops(model,
ost (stream): same as `file` param in :func:`print`.
Default: sys.stdout.
flush (bool): same as that in :func:`print`. Default: False.

Example:
>>> class ExampleModel(nn.Module):

>>> def __init__(self):
>>> super().__init__()
>>> self.conv1 = nn.Conv2d(3, 8, 3)
Expand All @@ -224,7 +211,6 @@ def print_model_with_flops(model,
>>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
>>> self.flatten = nn.Flatten()
>>> self.fc = nn.Linear(8, 1)

>>> def forward(self, x):
>>> x = self.conv1(x)
>>> x = self.conv2(x)
Expand All @@ -233,7 +219,6 @@ def print_model_with_flops(model,
>>> x = self.flatten(x)
>>> x = self.fc(x)
>>> return x

>>> model = ExampleModel()
>>> x = (3, 16, 16)
to print the complexity inforamtion state for each layer, you can use
Expand Down Expand Up @@ -308,7 +293,6 @@ def get_model_parameters_number(model):

Args:
model (nn.module): The model for parameter number calculation.

Returns:
float: Parameter number of the model.
"""
Expand Down Expand Up @@ -338,7 +322,6 @@ def compute_average_flops_cost(self):

A method to compute average FLOPs cost, which will be available after
`add_flops_counting_methods()` is called on a desired net object.

Returns:
float: Current mean flops consumption per image.
"""
Expand Down Expand Up @@ -426,11 +409,12 @@ def pool_flops_counter_hook(module, input, output):
module.__flops__ += int(np.prod(input.shape))


def bn_flops_counter_hook(module, input, output):
def norm_flops_counter_hook(module, input, output):
input = input[0]

batch_flops = np.prod(input.shape)
if module.affine:
if (getattr(module, 'affine', False)
or getattr(module, 'elementwise_affine', False)):
batch_flops *= 2
module.__flops__ += int(batch_flops)

Expand Down Expand Up @@ -577,10 +561,15 @@ def get_modules_mapping():
nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
# BNs
nn.BatchNorm1d: bn_flops_counter_hook,
nn.BatchNorm2d: bn_flops_counter_hook,
nn.BatchNorm3d: bn_flops_counter_hook,
# normalizations
nn.BatchNorm1d: norm_flops_counter_hook,
nn.BatchNorm2d: norm_flops_counter_hook,
nn.BatchNorm3d: norm_flops_counter_hook,
nn.GroupNorm: norm_flops_counter_hook,
nn.InstanceNorm1d: norm_flops_counter_hook,
nn.InstanceNorm2d: norm_flops_counter_hook,
nn.InstanceNorm3d: norm_flops_counter_hook,
nn.LayerNorm: norm_flops_counter_hook,
# FC
nn.Linear: linear_flops_counter_hook,
mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
Expand Down
12 changes: 9 additions & 3 deletions tests/test_cnn/test_flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@
{'model': nn.AdaptiveAvgPool1d(2), 'input': (3, 16), 'flops': 48.0, 'params': 0}, # noqa: E501
{'model': nn.AdaptiveAvgPool2d(2), 'input': (3, 16, 16), 'flops': 768.0, 'params': 0}, # noqa: E501
{'model': nn.AdaptiveAvgPool3d(2), 'input': (3, 3, 16, 16), 'flops': 2304.0, 'params': 0}, # noqa: E501
{'model': nn.BatchNorm1d(3, 8), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm2d(3, 8), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm3d(3, 8), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm1d(3), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm2d(3), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
{'model': nn.BatchNorm3d(3), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
{'model': nn.GroupNorm(2, 6), 'input': (6, 16, 16), 'flops': 3072.0, 'params': 12.0}, # noqa: E501
{'model': nn.InstanceNorm1d(3, affine=True), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
{'model': nn.InstanceNorm2d(3, affine=True), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
{'model': nn.InstanceNorm3d(3, affine=True), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
{'model': nn.LayerNorm((3, 16, 16)), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 1536.0}, # noqa: E501
{'model': nn.LayerNorm((3, 16, 16), elementwise_affine=False), 'input': (3, 16, 16), 'flops': 768.0, 'params': 0}, # noqa: E501
{'model': nn.Linear(1024, 2), 'input': (1024, ), 'flops': 2048.0, 'params': 2050.0}, # noqa: E501
{'model': nn.ConvTranspose2d(3, 8, 3), 'input': (3, 16, 16), 'flops': 57888, 'params': 224.0}, # noqa: E501
{'model': nn.Upsample((32, 32)), 'input': (3, 16, 16), 'flops': 3072.0, 'params': 0} # noqa: E501
Expand Down