This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Refactor flops counter #3048

Merged
merged 38 commits into from
Nov 23, 2020

Conversation

Contributor @colorjam commented Oct 30, 2020

  • Format each layer info for easy debug
  • Solve multiple calls of one module

@liuzhe-lz liuzhe-lz mentioned this pull request Oct 30, 2020
77 tasks
Contributor @QuanluZhang commented:

@colorjam please resolve conflict

@colorjam colorjam changed the base branch from v2.0 to master November 4, 2020 06:35

# Given input tensor with size (1, 1, 28, 28) and switch to full mode
x = torch.randn(1, 1, 28, 28)
flops, params, results = count_flops_params(model, (x, ), mode='full')
Contributor:

The parameter `(x, )` seems inconsistent with the code. Should it be `x`?

Contributor Author:

Thanks~ fixed it, please review the latest version~

"module_type": type(m).__name__
}

add_results(m._name, **results)
Contributor @chicm-ms commented Nov 9, 2020:

How about having the count_xxx functions just return the results, and doing the add_results outside? It would make the count_xxx functions cleaner.

We could use a common hook for register_forward_hook and call count_xxx per module type inside the common hook.

Contributor Author:

Thanks. I have updated the code, please review the latest version~

Contributor:

count_xxx still calls self._push_results. What I mean is: define another common hook function, for example myhook, that calls count_xxx by module type and then does _push_results inside myhook. Register it as the hook with something like m.register_forward_hook(myhook). That way count_xxx is a pure utility function with a single responsibility and can be reused later.

Contributor Author:

Fixed it~
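The refactor suggested above can be sketched without NNI's internals. A minimal sketch, assuming simplified names: `ModelProfiler`, `count_linear`, and `_push_results` follow the discussion, the `Linear` stand-in and the hook signature are hypothetical (the real code hooks PyTorch modules via `register_forward_hook`):

```python
import functools

# Hypothetical stand-in for nn.Linear, only for illustration.
class Linear:
    def __init__(self, in_f, out_f):
        self.in_features, self.out_features = in_f, out_f

class ModelProfiler:
    def __init__(self):
        self.results = []
        # dispatch table: module type -> pure counting function
        self.ops = {Linear: self.count_linear}

    @staticmethod
    def count_linear(m):
        # pure utility: returns the counts, does not push them anywhere
        return {"flops": m.in_features * m.out_features,
                "params": m.in_features * m.out_features,
                "module_type": type(m).__name__}

    def my_hook(self, module, name):
        # the common hook: dispatch by module type, push results here
        counted = self.ops[type(module)](module)
        self._push_results(name, **counted)

    def _push_results(self, name, **results):
        self.results.append(dict(name=name, **results))

profiler = ModelProfiler()
fc = Linear(28 * 28, 10)
# in real code this would be:
#   fc.register_forward_hook(functools.partial(profiler.my_hook, name="fc"))
hook = functools.partial(profiler.my_hook, fc, name="fc")
hook()  # stands in for the forward pass triggering the hook
```

The counting function stays reusable on its own, while all bookkeeping lives in the common hook.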

@@ -121,14 +121,19 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
```

## Model FLOPs/Parameters Counter
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks; it can also calculate FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for structured pruning, we only identify the remaining filters according to their masks and do not take the pruned input channels into consideration, so the calculated FLOPs will be larger than the real number (i.e., the number calculated after Model Speedup).

We support two modes for collecting module information. The first mode is `default`, which only collects the information of convolution and linear layers. The second mode is `full`, which also collects the information of other operations. Users can easily use the collected `results` for further analysis.
Contributor:

use a new paragraph for your added content

Contributor Author:

Fixed it~

input_size: list, tuple
the input shape of data
x: tuple or tensor
the input shape of data or a tensor as input data
custom_ops: dict
a mapping of (module: custom operation)
Contributor:

It is not clear what the type of "custom operation" is. Is it a function? And is the key the module type?

Contributor Author:

Yes, the user should manually define a function for counting the FLOPs of the module.
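The contract described above can be sketched as follows; the `MyOp` class and its assumed cost are purely illustrative, the point is that `custom_ops` maps a module type to a user-defined counting function:

```python
# Hypothetical module with a cost the built-in counters don't know about.
class MyOp:
    def __init__(self, features):
        self.features = features

def count_my_op(module):
    # user-defined counting function: returns the flops this module performs
    # (2 * features is an assumed cost, for illustration only)
    return 2 * module.features

# the custom_ops mapping: module type -> counting function
custom_ops = {MyOp: count_my_op}

m = MyOp(features=128)
flops = custom_ops[type(m)](m)
```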


Returns
-------
flops: float
total flops of the model
params:
total params of the model
results: dict
Contributor:

only write type here, for example, dict


Returns
-------
flops: float
total flops of the model
params:
total params of the model
results: dict
detail information of modules
Contributor:

what kind of information?


device = next(model.parameters()).device
inputs = torch.randn(input_size).to(device)
if torch.is_tensor(x[0]):
Contributor:

so here x is a list?

Contributor Author:

Yes, `x` holds the positional arguments given to the module.
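That answers the earlier `(x, )`-vs-`x` question as well: if the counter unpacks the tuple as positional arguments, multi-input models work for free. A rough sketch under that assumption (the `run_model` helper and the two-input model are hypothetical, not NNI code):

```python
# Hypothetical two-input model to show why a tuple of inputs is useful.
class TwoInputModel:
    def __call__(self, a, b):
        return a + b

def run_model(model, x):
    # mirrors the assumed `model(*x)` call inside the counter:
    # a bare tensor is wrapped, a tuple is splatted as positional args
    if not isinstance(x, tuple):
        x = (x,)
    return model(*x)

out = run_model(TwoInputModel(), (1, 2))
```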

assert input_size is not None
class ModelProfiler:
# use a class to share state to hooks
# profile results are available in `self.results`
Contributor:

please use standard docstring format

Contributor:

This is a private implementation class; it is not visible to users (not in `__all__`) anyway.

Count FLOPs and Params of the given model. This function identifies the mask on the module
and takes the pruned shape into consideration. Note that, for structured pruning, we only identify
the remaining filters according to their masks and do not take the pruned input channels into consideration,
so the calculated FLOPs will be larger than the real number.
Contributor:

Add one more blank line here too.

Contributor Author:

added
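The overestimation the docstring describes can be checked with a small worked example (the shapes are made up for illustration): a 3x3 conv with 64 input and 64 output channels on a 28x28 output map, where pruning removed half of the previous layer's filters, i.e. half of this layer's input channels.

```python
def conv_flops(c_in, c_out, k, out_h, out_w):
    # multiply-accumulate count for a conv layer (bias ignored)
    return c_in * c_out * k * k * out_h * out_w

# the counter's view: the mask on this layer keeps all 64 filters,
# but the 32 pruned input channels are still counted
counted = conv_flops(64, 64, 3, 28, 28)

# after Model Speedup the pruned input channels are really gone
real = conv_flops(32, 64, 3, 28, 28)

assert counted == 2 * real  # the counter overestimates by exactly 2x here
```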

the mode of how to collect information. If the mode is set to `default`,
only the information of convolution and linear will be collected.
If the mode is set to `full`, other operations will also be collected.
Returns
Contributor:

Also here.

Contributor Author:

added
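The `default`-vs-`full` distinction described above amounts to a filter over the collected entries. A sketch of that idea; the `DEFAULT_TYPES` set, `filter_results` helper, and sample entries are assumptions for illustration, not NNI's actual internals:

```python
DEFAULT_TYPES = {"Conv2d", "Linear"}  # assumed set for `default` mode

def filter_results(results, mode="default"):
    # `full` keeps every collected entry; `default` keeps only conv/linear
    if mode == "full":
        return list(results)
    return [r for r in results if r["module_type"] in DEFAULT_TYPES]

results = [
    {"name": "conv1", "module_type": "Conv2d"},
    {"name": "relu1", "module_type": "ReLU"},
    {"name": "fc", "module_type": "Linear"},
]
default_view = filter_results(results)            # conv1, fc
full_view = filter_results(results, mode="full")  # all three entries
```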


if len(list(m.children())) == 0 and type(m) in profiler.ops:
# if a leaf node
_handler = m.register_forward_hook(functools.partial(profiler.counte_module, name=name))
Contributor:

counte_module -> count_module

@liuzhe-lz liuzhe-lz closed this Nov 20, 2020
@liuzhe-lz liuzhe-lz reopened this Nov 20, 2020
@liuzhe-lz liuzhe-lz merged commit b6233e5 into microsoft:master Nov 23, 2020