Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Refactor flops counter #3048

Merged
merged 38 commits into from
Nov 23, 2020
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
1817648
update title level
Oct 16, 2020
5bb7527
support layer info
Oct 30, 2020
b939e0c
update
Nov 2, 2020
d290d20
Merge branch 'master' into counter-dev
Nov 4, 2020
ea6050a
update
Nov 4, 2020
8c639c4
Merge branch 'master' of https://github.com/microsoft/nni
Nov 4, 2020
edf0938
Merge branch 'master' into counter-dev
Nov 4, 2020
828729a
Merge branch 'master' of https://github.com/microsoft/nni into counte…
Nov 4, 2020
f8e0d2d
delete files
Nov 4, 2020
136d274
delete files
Nov 4, 2020
6a4f520
format result
Nov 4, 2020
464adae
update
Nov 4, 2020
0ab3ed6
remove ut in counter
Nov 4, 2020
bb76bce
fix pipeline
colorjam Nov 5, 2020
d9a0a45
fix pipeline
Nov 5, 2020
5187d1c
fix pipeline
Nov 5, 2020
53e7b75
fix pipeline
Nov 5, 2020
615a461
Merge branch 'master' of https://github.com/microsoft/nni into counte…
colorjam Nov 10, 2020
bffdd59
update according to reviews comments
colorjam Nov 10, 2020
2f45c26
fix pipeline
colorjam Nov 11, 2020
d4483e4
update
colorjam Nov 11, 2020
40da63e
pretty docstring
Nov 12, 2020
5bcdead
add blank line
Nov 12, 2020
a39276e
remove whitespace
Nov 12, 2020
59cdf84
remove whitespace
Nov 12, 2020
852907d
update
Nov 12, 2020
7534117
add line
Nov 12, 2020
e41f2b7
Fix ut
colorjam Nov 13, 2020
3558fa1
Remove ut in counter
colorjam Nov 15, 2020
675d292
Update up
colorjam Nov 15, 2020
ba7770d
Use single module counter
colorjam Nov 17, 2020
78fae9d
Fix minus typo
colorjam Nov 17, 2020
9099310
Fix minus typo
colorjam Nov 17, 2020
7fd24a7
Add get result function
colorjam Nov 17, 2020
472f051
Fix docstring and a few other typos
ultmaster Nov 17, 2020
05eac2b
Fix pylint
colorjam Nov 18, 2020
2c3cd16
Fix pipeline
colorjam Nov 18, 2020
82ab8e0
Remove whitespace
colorjam Nov 18, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions docs/en_US/Compression/CompressionUtils.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,28 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
```

## Model FLOPs/Parameters Counter
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks, it can also calculates FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for sturctured pruning, we only identify the remained filters according to its mask, which not taking the pruned input channels into consideration, so the calculated FLOPs will be larger than real number (i.e., the number calculated after Model Speedup).
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks, it can also calculates FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for sturctured pruning, we only identify the remained filters according to its mask, which not taking the pruned input channels into consideration, so the calculated FLOPs will be larger than real number (i.e., the number calculated after Model Speedup).

We support two modes to collect information of modules. The first mode is `default`, which only collect the information of convolution and linear. The second mode is `full`, which also collect the information of other operations. Users can easily use our collected `results` for futher analysis.
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved

### Usage
```
from nni.compression.pytorch.utils.counter import count_flops_params

# Given input size (1, 1, 28, 28)
flops, params = count_flops_params(model, (1, 1, 28, 28))
# Given input size (1, 1, 28, 28)
flops, params, results = count_flops_params(model, (1, 1, 28, 28))

# Given input tensor with size (1, 1, 28, 28) and switch to full mode
x = torch.randn(1, 1, 28, 28)

flops, params, results = count_flops_params(model, (x,) mode='full') # tuple of tensor as input

# Format output size to M (i.e., 10^6)
print(f'FLOPs: {flops/1e6:.3f}M, Params: {params/1e6:.3f}M)
print(results)
{
'conv': {'flops': [60], 'params': [20], 'weight_size': [(5, 3, 1, 1)], 'input_size': [(1, 3, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']},
'conv2': {'flops': [100], 'params': [30], 'weight_size': [(5, 5, 1, 1)], 'input_size': [(1, 5, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']}
}

```
Loading