Conversation
colorjam commented Oct 30, 2020 (edited)
- Format each layer's info for easy debugging
- Handle multiple calls of one module
@colorjam please resolve conflict
# Given input tensor with size (1, 1, 28, 28) and switch to full mode
x = torch.randn(1, 1, 28, 28)
flops, params, results = count_flops_params(model, (x, ), mode='full')
The parameter `(x,)` seems not consistent with the code. Should it be `x`?
Thanks~ fixed it, please review the latest version~
    "module_type": type(m).__name__
}

add_results(m._name, **results)
How about `count_xxx` functions just return the results, and do the `add_results` outside? It will make the `count_xxx` functions cleaner. We can use a common hook for `register_forward_hook` and call `count_xxx` per module type inside the common hook.
Thanks. I have updated the code, please review the latest version~
`countxxx` still calls `self._push_results`. I mean: define another common hook function, for example `myhook`, to call `countxxx` by module type, then do `_push_results` in `myhook`. Register it as the hook using something like `m.register_forward_hook(myhook)`; in this way `countxxx` is a pure util function with a single responsibility and can be reused later.
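The refactor being proposed here can be sketched as follows. This is a hypothetical illustration with plain-Python stand-ins for torch modules (the class names, `COUNT_FNS` registry, and `my_hook` are invented for the sketch, not the PR's actual code): each `count_xxx` is a pure function that only returns its numbers, while a single common hook dispatches by module type and does the result-pushing itself.

```python
class Linear:
    """Minimal stand-in for torch.nn.Linear (attributes only)."""
    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

def count_linear(module, inputs, output):
    # pure utility with a single responsibility: compute and return counts,
    # no side effects, so it can be reused outside the profiler
    flops = module.in_features * module.out_features
    params = module.in_features * module.out_features + module.out_features
    return {"module_type": type(module).__name__,
            "flops": flops, "params": params}

# dispatch table: module type -> counting function
COUNT_FNS = {Linear: count_linear}

class Profiler:
    def __init__(self):
        self.results = []

    def my_hook(self, module, inputs, output):
        # the common hook: look up the counter by module type,
        # then do the "_push_results" step here, not inside count_xxx
        count_fn = COUNT_FNS.get(type(module))
        if count_fn is not None:
            self.results.append(count_fn(module, inputs, output))

# in real torch code this hook would be registered per module, e.g.:
#   m.register_forward_hook(profiler.my_hook)
```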
fix it~
@@ -121,14 +121,19 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
```

## Model FLOPs/Parameters Counter
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating the FLOPs/parameters of a normal model without masks; it can also calculate the FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for structured pruning, we only identify the remained filters according to their masks, without taking the pruned input channels into consideration, so the calculated FLOPs will be larger than the real number (i.e., the number calculated after Model Speedup).
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating the FLOPs/parameters of a normal model without masks; it can also calculate the FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for structured pruning, we only identify the remained filters according to their masks, without taking the pruned input channels into consideration, so the calculated FLOPs will be larger than the real number (i.e., the number calculated after Model Speedup). We support two modes to collect information of modules. The first mode is `default`, which only collects the information of convolution and linear layers. The second mode is `full`, which also collects the information of other operations. Users can easily use the collected `results` for further analysis.
use a new paragraph for your added content
fix it~
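For intuition, the difference between the two modes can be thought of as a filter over the collected per-module results. This is a hypothetical sketch, not the PR's actual implementation; the entry shape follows the per-module dicts shown earlier in this PR, which carry a `"module_type"` field, and the `DEFAULT_TYPES` list is an assumption.

```python
# 'default' keeps only convolution/linear entries; 'full' keeps everything
DEFAULT_TYPES = ("Conv1d", "Conv2d", "Conv3d", "Linear")

def select_results(all_results, mode="default"):
    # all_results: list of per-module dicts, each with a "module_type" key
    if mode == "full":
        return list(all_results)
    return [r for r in all_results if r["module_type"] in DEFAULT_TYPES]
```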
input_size: list, tuple
    the input shape of data
x: tuple or tensor
    the input shape of data or a tensor as input data
custom_ops: dict
    a mapping of (module: custom operation)
It is not clear what the type of "custom operation" is. Is it a function? Is the key a module type?
Yes, the user should manually define a function for counting the FLOPs of the module.
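As an illustration of that answer, a `custom_ops` entry might look like the following. The module class, the counting function's name and signature, and the FLOPs formula are all hypothetical; the actual counter's expected hook signature should be checked against the PR's code.

```python
class MySwish:
    """Stand-in for a user-defined module the built-in counter doesn't know."""
    pass

def count_my_swish(module, output_size):
    # assumption for this sketch: one multiply and one sigmoid per output element
    n = 1
    for d in output_size:
        n *= d
    return 2 * n

# the key is the module type; the value is the user-defined counting function
custom_ops = {MySwish: count_my_swish}
# hypothetically passed as: count_flops_params(model, x, custom_ops=custom_ops)
```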
Returns
-------
flops: float
    total flops of the model
params:
    total params of the model
results: dict
only write type here, for example, dict
Returns
-------
flops: float
    total flops of the model
params:
    total params of the model
results: dict
    detail information of modules
what kind of information?
device = next(model.parameters()).device
inputs = torch.randn(input_size).to(device)
if torch.is_tensor(x[0]):
So here `x` is a list?
Yes, `x` are the positional arguments given to the module.
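The `torch.is_tensor(x[0])` check above implies `x` arrives as a sequence of positional arguments. A minimal normalization helper along these lines (hypothetical, not part of the PR) would let a caller pass either a bare tensor or a tuple and always end up with something callable as `model(*x)`:

```python
def normalize_inputs(x):
    # a bare input (e.g. a single tensor) becomes a one-element tuple,
    # so the model can always be invoked uniformly as model(*x)
    if isinstance(x, (tuple, list)):
        return tuple(x)
    return (x,)
```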
assert input_size is not None
class ModelProfiler:
    # use a class to share state to hooks
    # profile results are available in `self.results`
please use standard docstring format
This is a private implementation class. It is not visible to the user (not in `__all__`) anyway.
Count FLOPs and Params of the given model. This function would identify the mask on the module
and take the pruned shape into consideration. Note that, for structured pruning, we only identify
the remained filters according to its mask, and do not take the pruned input channels into consideration,
so the calculated FLOPs will be larger than the real number.
Add one more blank line here too.
added
the mode of how to collect information. If the mode is set to `default`,
only the information of convolution and linear will be collected.
If the mode is set to `full`, other operations will also be collected.
Returns
Also here.
added
if len(list(m.children())) == 0 and type(m) in profiler.ops:
    # if a leaf node
    _handler = m.register_forward_hook(functools.partial(profiler.counte_module, name=name))
`counte_module` -> `count_module`