This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add several speedup examples (#3880)
- Loading branch information
1 parent
5fe2450
commit 47c7ea1
Showing
6 changed files
with
109 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
21 changes: 21 additions & 0 deletions
21
examples/model_compress/pruning/speedup/speedup_mobilnetv2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
from torchvision.models import mobilenet_v2 | ||
from nni.compression.pytorch import ModelSpeedup | ||
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner | ||
|
||
|
||
model = mobilenet_v2(pretrained=True) | ||
dummy_input = torch.rand(8, 3, 416, 416) | ||
|
||
cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5}] | ||
pruner = L1FilterPruner(model, cfg_list) | ||
pruner.compress() | ||
pruner.export_model('./model', './mask') | ||
# you need to call _unwrap_model if you want to run the speedup on the same model | ||
pruner._unwrap_model() | ||
|
||
# Speed up the MobileNetV2 model | ||
ms = ModelSpeedup(model, dummy_input, './mask') | ||
ms.speedup_model() | ||
|
||
model(dummy_input) |
39 changes: 39 additions & 0 deletions
39
examples/model_compress/pruning/speedup/speedup_nanodet.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
from nanodet.model.arch import build_model | ||
from nanodet.util import cfg, load_config | ||
|
||
from nni.compression.pytorch import ModelSpeedup | ||
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner | ||
|
||
""" | ||
NanoDet model can be installed from https://github.com/RangiLyu/nanodet.git | ||
""" | ||
|
||
cfg_path = r"nanodet/config/nanodet-RepVGG-A0_416.yml" | ||
load_config(cfg, cfg_path) | ||
|
||
model = build_model(cfg.model) | ||
dummy_input = torch.rand(8, 3, 416, 416) | ||
|
||
op_names = [] | ||
# These three conv layers are followed by reshape-like functions | ||
# that cannot be replaced, so we skip them. | ||
# You can also find such layers with the `not_safe_to_prune` function. | ||
excludes = ['head.gfl_cls.0', 'head.gfl_cls.1', 'head.gfl_cls.2'] | ||
for name, module in model.named_modules(): | ||
if isinstance(module, torch.nn.Conv2d): | ||
if name not in excludes: | ||
op_names.append(name) | ||
|
||
cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5, 'op_names':op_names}] | ||
pruner = L1FilterPruner(model, cfg_list) | ||
pruner.compress() | ||
pruner.export_model('./model', './mask') | ||
# you need to call _unwrap_model if you want to run the speedup on the same model | ||
pruner._unwrap_model() | ||
|
||
# Speed up the NanoDet model | ||
ms = ModelSpeedup(model, dummy_input, './mask') | ||
ms.speedup_model() | ||
|
||
model(dummy_input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
from pytorchyolo import models | ||
|
||
from nni.compression.pytorch import ModelSpeedup | ||
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner | ||
from nni.compression.pytorch.utils import not_safe_to_prune | ||
|
||
# The YOLOv3 implementation can be downloaded from https://github.com/eriklindernoren/PyTorch-YOLOv3.git | ||
prefix = '/home/user/PyTorch-YOLOv3' # replace this path with yours | ||
# Load the YOLO model | ||
model = models.load_model( | ||
"%s/config/yolov3.cfg" % prefix, | ||
"%s/yolov3.weights" % prefix) | ||
model.eval() | ||
dummy_input = torch.rand(8, 3, 320, 320) | ||
model(dummy_input) | ||
# Generate the config list for pruner | ||
# Filter out the layers that may not be safe to prune | ||
not_safe = not_safe_to_prune(model, dummy_input) | ||
cfg_list = [] | ||
for name, module in model.named_modules(): | ||
if name in not_safe: | ||
continue | ||
if isinstance(module, torch.nn.Conv2d): | ||
cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.6, 'op_names':[name]}) | ||
# Prune the model | ||
pruner = L1FilterPruner(model, cfg_list) | ||
pruner.compress() | ||
pruner.export_model('./model', './mask') | ||
pruner._unwrap_model() | ||
# Speedup the model | ||
ms = ModelSpeedup(model, dummy_input, './mask') | ||
|
||
ms.speedup_model() | ||
model(dummy_input) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters