Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Add several speedup examples (#3880)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-ningxin authored Jul 15, 2021
1 parent 5fe2450 commit 47c7ea1
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision import datasets, transforms

import sys
sys.path.append('../models')
sys.path.append('../../models')
from cifar10.vgg import VGG
from mnist.lenet import LeNet

Expand Down
21 changes: 21 additions & 0 deletions examples/model_compress/pruning/speedup/speedup_mobilnetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from torchvision.models import mobilenet_v2
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner


model = mobilenet_v2(pretrained=True)
dummy_input = torch.rand(8, 3, 416, 416)

cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5}]
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
# need call _unwrap_model if you want run the speedup on the same model
pruner._unwrap_model()

# Speedup the nanodet
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()

model(dummy_input)
39 changes: 39 additions & 0 deletions examples/model_compress/pruning/speedup/speedup_nanodet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
from nanodet.model.arch import build_model
from nanodet.util import cfg, load_config

from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

"""
NanoDet model can be installed from https://github.com/RangiLyu/nanodet.git
"""

cfg_path = r"nanodet/config/nanodet-RepVGG-A0_416.yml"
load_config(cfg, cfg_path)

model = build_model(cfg.model)
dummy_input = torch.rand(8, 3, 416, 416)

op_names = []
# these three conv layers are followed by reshape-like functions
# that cannot be replaced, so we skip these three conv layers,
# you can also get such layers by `not_safe_to_prune` function
excludes = ['head.gfl_cls.0', 'head.gfl_cls.1', 'head.gfl_cls.2']
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
if name not in excludes:
op_names.append(name)

cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5, 'op_names':op_names}]
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
# need call _unwrap_model if you want run the speedup on the same model
pruner._unwrap_model()

# Speedup the nanodet
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()

model(dummy_input)
36 changes: 36 additions & 0 deletions examples/model_compress/pruning/speedup/speedup_yolov3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from pytorchyolo import models

from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.compression.pytorch.utils import not_safe_to_prune

# The Yolo can be downloaded at https://github.com/eriklindernoren/PyTorch-YOLOv3.git
prefix = '/home/user/PyTorch-YOLOv3' # replace this path with yours
# Load the YOLO model
model = models.load_model(
"%s/config/yolov3.cfg" % prefix,
"%s/yolov3.weights" % prefix)
model.eval()
dummy_input = torch.rand(8, 3, 320, 320)
model(dummy_input)
# Generate the config list for pruner
# Filter the layers that may not be able to prune
not_safe = not_safe_to_prune(model, dummy_input)
cfg_list = []
for name, module in model.named_modules():
if name in not_safe:
continue
if isinstance(module, torch.nn.Conv2d):
cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.6, 'op_names':[name]})
# Prune the model
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
pruner._unwrap_model()
# Speedup the model
ms = ModelSpeedup(model, dummy_input, './mask')

ms.speedup_model()
model(dummy_input)

9 changes: 8 additions & 1 deletion nni/compression/pytorch/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,14 @@ def randomize_tensor(tensor, start=1, end=100):

def not_safe_to_prune(model, dummy_input):
"""
Get the layers that are safe to prune(will not bring the shape conflict).
Get the layers that are not safe to prune(may bring the shape conflict).
For example, if the output tensor of a conv layer is directly followed by
a shape-dependent function(such as reshape/view), then this conv layer
may be not safe to be pruned. Pruning may change the output shape of
this conv layer and result in shape problems. This function find all the
layers that directly followed by the shape-dependent functions(view, reshape, etc).
If you run the inference after the speedup and run into a shape related error,
please exclude the layers returned by this function and try again.
Parameters
----------
Expand Down
4 changes: 4 additions & 0 deletions test/ut/sdk/test_model_speedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,10 @@ def test_speedup_integration_big(self):
self.speedup_integration(model_list)

def speedup_integration(self, model_list, speedup_cfg=None):
# Note: hack trick, may be updated in the future
if 'win' in sys.platform or 'Win'in sys.platform:
print('Skip test_speedup_integration on windows due to memory limit!')
return
Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]

# for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',
Expand Down

0 comments on commit 47c7ea1

Please sign in to comment.