This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
Add several speedup examples #3880

Merged: zheng-ningxin merged 7 commits into microsoft:master from zheng-ningxin:speedup_example on Jul 15, 2021

Commits (7):

- 1751fc5 examples
- 9a18c7d update
- 0bba07c Merge branch 'master' of https://github.com/microsoft/nni into speedu…
- 7eefd88 update docstring
- e082aef skip the speedup integration test on windows, too slow
- 80fb98e fix sys.path
- 664aeca update
21 changes: 21 additions & 0 deletions
examples/model_compress/pruning/speedup/speedup_mobilnetv2.py

```python
import torch
from torchvision.models import mobilenet_v2
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

model = mobilenet_v2(pretrained=True)
dummy_input = torch.rand(8, 3, 416, 416)

cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
# need to call _unwrap_model if you want to run the speedup on the same model
pruner._unwrap_model()

# Speed up the MobileNetV2 model
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()

model(dummy_input)
```
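For intuition on what `L1FilterPruner` decides: it ranks each convolution's output filters by the L1 norm of their weights and masks the smallest ones until the requested sparsity is reached. A minimal pure-Python sketch of that ranking — the function names and toy weight lists below are illustrative, not NNI's internals:

```python
def l1_filter_scores(filters):
    """Score each filter by its L1 norm (sum of absolute weights).

    `filters` is a list of flat weight lists, one per output channel,
    mimicking a Conv2d weight of shape (out_channels, in_ch * k * k).
    """
    return [sum(abs(w) for w in f) for f in filters]

def filters_to_prune(filters, sparsity):
    """Return indices of the filters with the smallest L1 norm.

    With sparsity=0.5, half of the filters are selected for masking.
    """
    scores = l1_filter_scores(filters)
    n_prune = int(len(filters) * sparsity)
    ranked = sorted(range(len(filters)), key=lambda i: scores[i])
    return sorted(ranked[:n_prune])

# Four toy filters; the two with the smallest absolute weights get pruned.
weights = [
    [0.9, -0.8, 0.7],   # L1 = 2.4
    [0.1, -0.1, 0.05],  # L1 = 0.25 -> pruned
    [0.5, 0.4, -0.6],   # L1 = 1.5
    [0.0, 0.2, -0.1],   # L1 = 0.3  -> pruned
]
print(filters_to_prune(weights, 0.5))  # [1, 3]
```

Filters 1 and 3 are masked, matching the 0.5 sparsity in the config list above.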
39 changes: 39 additions & 0 deletions
examples/model_compress/pruning/speedup/speedup_nanodet.py

```python
import torch
from nanodet.model.arch import build_model
from nanodet.util import cfg, load_config

from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

"""
The NanoDet model can be installed from https://github.com/RangiLyu/nanodet.git
"""

cfg_path = r"nanodet/config/nanodet-RepVGG-A0_416.yml"
load_config(cfg, cfg_path)

model = build_model(cfg.model)
dummy_input = torch.rand(8, 3, 416, 416)

op_names = []
# These three conv layers are followed by reshape-like functions that
# cannot be replaced, so we skip them; you can also find such layers
# with the `not_safe_to_prune` function.
excludes = ['head.gfl_cls.0', 'head.gfl_cls.1', 'head.gfl_cls.2']
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        if name not in excludes:
            op_names.append(name)

cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5, 'op_names': op_names}]
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
# need to call _unwrap_model if you want to run the speedup on the same model
pruner._unwrap_model()

# Speed up the NanoDet model
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()

model(dummy_input)
```
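The name-collection loop above generalizes to any model: walk the module tree, keep the prunable layer types, and skip an exclude list. A small sketch with a mocked `named_modules()` sequence — the module names and the `(name, type_string)` pairs here are made up for illustration, standing in for torch's `(name, module)` pairs:

```python
def collect_prunable(named_modules, prunable_types, excludes):
    """Collect names of modules whose type is prunable and not excluded.

    `named_modules` yields (name, type_string) pairs, a stand-in for
    the (name, module) pairs that model.named_modules() yields in torch.
    """
    return [name for name, mod_type in named_modules
            if mod_type in prunable_types and name not in excludes]

# Mocked module tree; 'head.gfl_cls.0' is followed by a reshape-like op,
# so it goes on the exclude list as in the NanoDet example above.
modules = [
    ('backbone.conv1', 'Conv2d'),
    ('backbone.bn1', 'BatchNorm2d'),
    ('head.conv', 'Conv2d'),
    ('head.gfl_cls.0', 'Conv2d'),
]
op_names = collect_prunable(modules, {'Conv2d'}, {'head.gfl_cls.0'})
print(op_names)  # ['backbone.conv1', 'head.conv']
```

The resulting `op_names` list then goes into the pruner's config entry, exactly as in the example above.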
36 changes: 36 additions & 0 deletions

```python
import torch
from pytorchyolo import models

from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.compression.pytorch.utils import not_safe_to_prune

# The YOLO model can be downloaded from https://github.com/eriklindernoren/PyTorch-YOLOv3.git
prefix = '/home/user/PyTorch-YOLOv3'  # replace this path with yours
# Load the YOLO model
model = models.load_model(
    "%s/config/yolov3.cfg" % prefix,
    "%s/yolov3.weights" % prefix)
model.eval()
dummy_input = torch.rand(8, 3, 320, 320)
model(dummy_input)
# Generate the config list for the pruner
# Filter out the layers that may not be safe to prune
not_safe = not_safe_to_prune(model, dummy_input)
cfg_list = []
for name, module in model.named_modules():
    if name in not_safe:
        continue
    if isinstance(module, torch.nn.Conv2d):
        cfg_list.append({'op_types': ['Conv2d'], 'sparsity': 0.6, 'op_names': [name]})
# Prune the model
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
pruner._unwrap_model()
# Speed up the model
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()
model(dummy_input)
```

Review note on the `not_safe_to_prune` import: QuanluZhang marked this conversation as resolved.

Review comment on the `if name in not_safe:` line: maybe we can leverage "exclude" here in the config, @J-shang
Reply: Yes,
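Conceptually, the speedup step turns the masks the pruner produced into a physically smaller model: output channels whose mask is all zeros are dropped from the weight tensor, so later inference does less work instead of multiplying by zeros. A toy sketch of that shrinking step on a plain list-of-rows weight matrix — this is an illustration of the idea, not NNI's actual graph-based `ModelSpeedup` implementation:

```python
def shrink_rows(weight, mask):
    """Drop the rows of `weight` whose mask entry is 0.

    `weight` is a list of rows (one per output channel); `mask` holds
    1 to keep a channel and 0 where the pruner zeroed it out.
    """
    return [row for row, keep in zip(weight, mask) if keep]

weight = [
    [0.9, -0.8],  # kept
    [0.0, 0.0],   # masked out by the pruner
    [0.5, 0.4],   # kept
]
mask = [1, 0, 1]
small = shrink_rows(weight, mask)
print(len(small), len(weight))  # 2 3
```

In the real graph-level speedup, the consumers of a shrunk layer must shrink their input channels to match, which is why layers feeding reshape-like ops (the `not_safe_to_prune` set above) are skipped.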
Review comment: It is a little strange that the user has to install a model from another repo to run the example. If the model is not very complicated, can we add it to our model compression model files so that users can run it directly?

Reply: There is no need to put their code into our repo, but we can provide concrete commands in the comments about how to prepare the code.