-
Notifications
You must be signed in to change notification settings - Fork 360
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Docs] Translate "Model Complexity Analysis" to Chinese (#969)
* [Doc] Translate model complexity analysis into Chinese. * [Doc] Translate model complexity analysis into Chinese. * [Docs] fix the description of the interface * update introduction Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Update description of FLOPs Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Update activation Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Update model description Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Beautify code style Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Modify examples Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Upadate output description Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Update docs/zh_cn/advanced_tutorials/model_analysis.md Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Replace FLOPs with flop; fix typo * Fix typo * fix lint error * Update docs/zh_cn/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/zh_cn/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/zh_cn/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/zh_cn/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/zh_cn/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/zh_cn/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update model_analysis.md * Update model_analysis.md * Apply suggestions from code review Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> --------- Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
- Loading branch information
1 parent
8063d2c
commit 330985d
Showing
2 changed files
with
263 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,184 @@ | ||
# 模型复杂度分析 | ||
|
||
翻译中,请暂时阅读英文文档 [Model Complexity Analysis](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/model_analysis.html)。 | ||
我们提供了一个工具来帮助分析网络的复杂性。我们借鉴了 [fvcore](https://github.com/facebookresearch/fvcore) 的实现思路来构建这个工具,并计划在未来支持更多的自定义算子。目前的工具提供了用于计算给定模型的浮点运算量(FLOPs)、激活量(Activations)和参数量(Parameters)的接口,并支持以网络结构或表格的形式逐层打印相关信息,同时提供了算子级别(operator)和模块级别(Module)的统计。如果您对统计浮点运算量的实现细节感兴趣,请参考 [Flop Count](https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md)。 | ||
|
||
## 定义 | ||
|
||
模型复杂度有 3 个指标,分别是浮点运算量(FLOPs)、激活量(Activations)以及参数量(Parameters),它们的定义如下: | ||
|
||
- 浮点运算量 | ||
|
||
浮点运算量不是一个定义非常明确的指标,在这里参考 [detectron2](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.FlopCountAnalysis) 的描述,将一组乘加运算定义为 1 个 flop。 | ||
|
||
- 激活量 | ||
|
||
激活量用于衡量某一层产生的特征数量。 | ||
|
||
- 参数量 | ||
|
||
模型的参数量。 | ||
|
||
例如,给定输入尺寸 `inputs = torch.randn((1, 3, 10, 10))`,和一个卷积层 `conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)`,那么它输出的特征图尺寸为 `(1, 10, 8, 8)`,则它的浮点运算量是 `17280 = 10*8*8*3*3*3`(10*8*8 表示输出的特征图大小、3*3*3 表示每一个输出需要的计算量)、激活量是 `640 = 10*8*8`、参数量是 `280 = 3*10*3*3 + 10`(3*10*3\*3 表示权重的尺寸、10 表示偏置值的尺寸)。 | ||
|
||
## 用法 | ||
|
||
### 基于 `nn.Module` 构建的模型 | ||
|
||
构建模型 | ||
|
||
```python | ||
from torch import nn | ||
|
||
from mmengine.analysis import get_model_complexity_info | ||
|
||
|
||
# 以字典的形式返回分析结果,包括: | ||
# ['flops', 'flops_str', 'activations', 'activations_str', 'params', 'params_str', 'out_table', 'out_arch'] | ||
class InnerNet(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.fc1 = nn.Linear(10, 10) | ||
self.fc2 = nn.Linear(10, 10) | ||
|
||
def forward(self, x): | ||
return self.fc1(self.fc2(x)) | ||
|
||
|
||
class TestNet(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.fc1 = nn.Linear(10, 10) | ||
self.fc2 = nn.Linear(10, 10) | ||
self.inner = InnerNet() | ||
|
||
def forward(self, x): | ||
return self.fc1(self.fc2(self.inner(x))) | ||
|
||
|
||
input_shape = (1, 10) | ||
model = TestNet() | ||
``` | ||
|
||
`get_model_complexity_info` 返回的 `analysis_results` 是一个包含 7 个值的字典: | ||
|
||
- `flops`: flop 的总数, 例如, 1000, 1000000 | ||
- `flops_str`: 格式化的字符串, 例如, 1.0G, 1.0M | ||
- `params`: 全部参数的数量, 例如, 1000, 1000000 | ||
- `params_str`: 格式化的字符串, 例如, 1.0K, 1M | ||
- `activations`: 激活量的总数, 例如, 1000, 1000000 | ||
- `activations_str`: 格式化的字符串, 例如, 1.0G, 1M | ||
- `out_table`: 以表格形式打印相关信息 | ||
|
||
打印结果 | ||
|
||
- 以表格形式打印相关信息 | ||
|
||
```python | ||
print(analysis_results['out_table']) | ||
``` | ||
|
||
```text | ||
+---------------------+----------------------+--------+--------------+ | ||
| module | #parameters or shape | #flops | #activations | | ||
+---------------------+----------------------+--------+--------------+ | ||
| model | 0.44K | 0.4K | 40 | | ||
| fc1 | 0.11K | 100 | 10 | | ||
| fc1.weight | (10, 10) | | | | ||
| fc1.bias | (10,) | | | | ||
| fc2 | 0.11K | 100 | 10 | | ||
| fc2.weight | (10, 10) | | | | ||
| fc2.bias | (10,) | | | | ||
| inner | 0.22K | 0.2K | 20 | | ||
| inner.fc1 | 0.11K | 100 | 10 | | ||
| inner.fc1.weight | (10, 10) | | | | ||
| inner.fc1.bias | (10,) | | | | ||
| inner.fc2 | 0.11K | 100 | 10 | | ||
| inner.fc2.weight | (10, 10) | | | | ||
| inner.fc2.bias | (10,) | | | | ||
+---------------------+----------------------+--------+--------------+ | ||
``` | ||
|
||
- 以网络层级结构打印相关信息 | ||
|
||
```python | ||
print(analysis_results['out_arch']) | ||
``` | ||
|
||
```bash | ||
TestNet( | ||
#params: 0.44K, #flops: 0.4K, #acts: 40 | ||
(fc1): Linear( | ||
in_features=10, out_features=10, bias=True | ||
#params: 0.11K, #flops: 100, #acts: 10 | ||
) | ||
(fc2): Linear( | ||
in_features=10, out_features=10, bias=True | ||
#params: 0.11K, #flops: 100, #acts: 10 | ||
) | ||
(inner): InnerNet( | ||
#params: 0.22K, #flops: 0.2K, #acts: 20 | ||
(fc1): Linear( | ||
in_features=10, out_features=10, bias=True | ||
#params: 0.11K, #flops: 100, #acts: 10 | ||
) | ||
(fc2): Linear( | ||
in_features=10, out_features=10, bias=True | ||
#params: 0.11K, #flops: 100, #acts: 10 | ||
) | ||
) | ||
) | ||
``` | ||
|
||
- 以字符串的形式打印结果 | ||
|
||
```python | ||
print("Model Flops:{}".format(analysis_results['flops_str'])) | ||
# Model Flops:0.4K | ||
print("Model Parameters:{}".format(analysis_results['params_str'])) | ||
# Model Parameters:0.44K | ||
``` | ||
|
||
### 基于 BaseModel(来自 MMEngine)构建的模型 | ||
|
||
```python | ||
import torch.nn.functional as F | ||
import torchvision | ||
from mmengine.model import BaseModel | ||
from mmengine.analysis import get_model_complexity_info | ||
|
||
|
||
class MMResNet50(BaseModel): | ||
def __init__(self): | ||
super().__init__() | ||
self.resnet = torchvision.models.resnet50() | ||
|
||
def forward(self, imgs, labels=None, mode='tensor'): | ||
x = self.resnet(imgs) | ||
if mode == 'loss': | ||
return {'loss': F.cross_entropy(x, labels)} | ||
elif mode == 'predict': | ||
return x, labels | ||
elif mode == 'tensor': | ||
return x | ||
|
||
|
||
input_shape = (3, 224, 224) | ||
model = MMResNet50() | ||
|
||
analysis_results = get_model_complexity_info(model, input_shape) | ||
|
||
print("Model Flops:{}".format(analysis_results['flops_str'])) | ||
# Model Flops:4.145G | ||
print("Model Parameters:{}".format(analysis_results['params_str'])) | ||
# Model Parameters:25.557M | ||
``` | ||
|
||
## 其他接口 | ||
|
||
除了上述基本用法,`get_model_complexity_info` 还能接受以下参数,输出定制化的统计结果: | ||
|
||
- `model`: (nn.Module) 待分析的模型 | ||
- `input_shape`: (tuple) 输入尺寸,例如 (3, 224, 224) | ||
- `inputs`: (optional: torch.Tensor), 如果传入该参数, `input_shape` 会被忽略 | ||
- `show_table`: (bool) 是否以表格形式返回统计结果,默认值:True | ||
- `show_arch`: (bool) 是否以网络结构形式返回统计结果,默认值:True |