Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

refactor of modelspeedup example #2161

Merged
merged 11 commits into from
Mar 24, 2020
2 changes: 1 addition & 1 deletion examples/model_compress/model_prune_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def forward(self, x):
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
Expand Down
193 changes: 69 additions & 124 deletions examples/model_compress/model_speedup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import argparse
import time
import torch
Expand All @@ -9,145 +10,89 @@
from nni.compression.torch import apply_compression_results

torch.manual_seed(0)
use_mask = False
use_mask = True
use_speedup = True
compare_results = True

def apoz_speedup(masks_file, model_checkpoint):
device = torch.device('cuda')
model = VGG(depth=16)
model.to(device)
model.eval()

dummy_input = torch.randn(64, 3, 32, 32)
if use_mask:
apply_compression_results(model, masks_file)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('mask elapsed time: ', time.time() - start)
return
else:
#print("model before: ", model)
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
#print("model after: ", model)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('speedup elapsed time: ', time.time() - start)
return
config = {
'apoz': {
'model_name': 'vgg16',
'device': 'cuda',
'input_shape': [64, 3, 32, 32],
'masks_file': './checkpoints/mask_vgg16_cifar10_apoz.pth'
},
'l1filter': {
'model_name': 'vgg16',
'device': 'cuda',
'input_shape': [64, 3, 32, 32],
'masks_file': './checkpoints/mask_vgg16_cifar10_l1.pth'
},
'fpgm': {
'model_name': 'naive',
'device': 'cpu',
'input_shape': [64, 1, 28, 28],
'masks_file': './checkpoints/mask_naive_mnist_fpgm.pth'
},
'slim': {
'model_name': 'vgg19',
'device': 'cuda',
'input_shape': [64, 3, 32, 32],
'masks_file': './checkpoints/mask_vgg19_cifar10_slim.pth' #'mask_vgg19_cifar10.pth'
}
}

def l1filter_speedup(masks_file, model_checkpoint):
device = torch.device('cuda')
model = VGG(depth=16)
def model_inference(config):
masks_file = config['masks_file']
device = torch.device(config['device'])
if config['model_name'] == 'vgg16':
model = VGG(depth=16)
elif config['model_name'] == 'vgg19':
model = VGG(depth=19)
elif config['model_name'] == 'naive':
from model_prune_torch import NaiveModel
model = NaiveModel()
model.to(device)
model.eval()

dummy_input = torch.randn(64, 3, 32, 32)
dummy_input = torch.randn(config['input_shape']).to(device)
use_mask_out = use_speedup_out = None
# must run use_mask before use_speedup because use_speedup modify the model
if use_mask:
apply_compression_results(model, masks_file)
dummy_input = dummy_input.to(device)
apply_compression_results(model, masks_file, 'cpu' if config['device'] == 'cpu' else None)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('mask elapsed time: ', time.time() - start)
return
else:
#print("model before: ", model)
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
use_mask_out = model(dummy_input)
print('elapsed time when use mask: ', time.time() - start)
if use_speedup:
m_speedup = ModelSpeedup(model, dummy_input, masks_file,
'cpu' if config['device'] == 'cpu' else None)
m_speedup.speedup_model()
#print("model after: ", model)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('speedup elapsed time: ', time.time() - start)
return

def fpgm_speedup(masks_file, model_checkpoint):
from fpgm_torch_mnist import Mnist
device = torch.device('cpu')
model = Mnist()
model.to(device)
model.print_conv_filter_sparsity()

dummy_input = torch.randn(64, 1, 28, 28)
if use_mask:
apply_compression_results(model, masks_file)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(40):
out = model(dummy_input)
print('mask elapsed time: ', time.time() - start)
#print(out.size(), out)
return
else:
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(40):
out = model(dummy_input)
print('speedup elapsed time: ', time.time() - start)
#print(out.size(), out)
return

def slim_speedup(masks_file, model_checkpoint):
device = torch.device('cuda')
model = VGG(depth=19)
model.to(device)
model.eval()

dummy_input = torch.randn(64, 3, 32, 32)
if use_mask:
apply_compression_results(model, masks_file)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('mask elapsed time: ', time.time() - start)
return
else:
#print("model before: ", model)
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
#print("model after: ", model)
dummy_input = dummy_input.to(device)
start = time.time()
for _ in range(32):
out = model(dummy_input)
#print(out.size(), out)
print('speedup elapsed time: ', time.time() - start)
return
use_speedup_out = model(dummy_input)
print('elapsed time when use speedup: ', time.time() - start)
if compare_results:
if torch.allclose(use_mask_out, use_speedup_out, atol=1e-07):
print('the outputs from use_mask and use_speedup are the same')
else:
raise RuntimeError('the outputs from use_mask and use_speedup are different')

if __name__ == '__main__':
parser = argparse.ArgumentParser("speedup")
parser.add_argument("--example_name", type=str, default="slim", help="the name of pruning example")
parser.add_argument("--masks_file", type=str, default=None, help="the path of the masks file")
parser.add_argument("--model_checkpoint", type=str, default=None, help="the path of checkpointed model")
args = parser.parse_args()

if args.example_name == 'slim':
if args.masks_file is None:
args.masks_file = 'mask_vgg19_cifar10.pth'
slim_speedup(args.masks_file, args.model_checkpoint)
elif args.example_name == 'fpgm':
if args.masks_file is None:
args.masks_file = 'mask.pth'
fpgm_speedup(args.masks_file, args.model_checkpoint)
elif args.example_name == 'l1filter':
if args.masks_file is None:
args.masks_file = 'mask_vgg16_cifar10.pth'
l1filter_speedup(args.masks_file, args.model_checkpoint)
elif args.example_name == 'apoz':
if args.masks_file is None:
args.masks_file = 'mask_vgg16_cifar10.pth'
apoz_speedup(args.masks_file, args.model_checkpoint)

if args.example_name != 'all':
if args.masks_file is not None:
config[args.example_name]['masks_file'] = args.masks_file
if not os.path.exists(config[args.example_name]['masks_file']):
msg = '{} does not exist! You should specify masks_file correctly, '
msg += 'or use default one which is generated by model_prune_torch.py'
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError(msg.format(config[args.example_name]['masks_file']))
model_inference(config[args.example_name])
else:
raise ValueError('unsupported example_name: {}'.format(args.example_name))
model_inference(config['fpgm'])
model_inference(config['slim'])
model_inference(config['l1filter'])
model_inference(config['apoz'])
6 changes: 4 additions & 2 deletions src/sdk/pynni/nni/compression/speedup/torch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ModelSpeedup:
This class is to speedup the model with provided weight mask
"""

def __init__(self, model, dummy_input, masks_file):
def __init__(self, model, dummy_input, masks_file, map_location=None):
"""
Parameters
----------
Expand All @@ -80,10 +80,12 @@ def __init__(self, model, dummy_input, masks_file):
The dummy input for ```jit.trace```, users should put it on right device before pass in
masks_file : str
The path of user provided mask file
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
"""
self.bound_model = model
self.dummy_input = dummy_input
self.masks = torch.load(masks_file)
self.masks = torch.load(masks_file, map_location)
self.is_training = model.training
# to obtain forward graph, model should be in ```eval``` mode
if self.is_training:
Expand Down
63 changes: 11 additions & 52 deletions src/sdk/pynni/nni/compression/torch/apply_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,27 @@

import logging
import torch
from .compressor import Pruner

logger = logging.getLogger('torch apply compression')

def apply_compression_results(model, masks_file):
def apply_compression_results(model, masks_file, map_location=None):
"""
Apply the masks from ```masks_file``` to the model
Note: this API is for inference, because it simply multiplies weights with
corresponding masks when this API is called.

Parameters
----------
model : torch.nn.module
The model to be compressed
masks_file : str
The path of the mask file
map_location : str
the device on which masks are placed, same to map_location in ```torch.load```
"""
apply_comp = ApplyCompression(model, masks_file)
apply_comp.compress()

class ApplyCompression(Pruner):
"""
This class is not to generate masks, but applying existing masks
"""

def __init__(self, model, masks_file):
"""
Parameters
----------
model : torch.nn.module
Model to be masked
masks_file : str
The path of user provided mask file
"""
self.bound_model = model
self.masks = torch.load(masks_file)
for module_name in self.masks:
print('module_name: ', module_name)
config_list = self._build_config()
super().__init__(model, config_list)

def _build_config(self):
op_names = []
for module_name in self.masks:
op_names.append(module_name)
return [{'sparsity': 1, 'op_types': ['default', 'BatchNorm2d'], 'op_names': op_names}]

def calc_mask(self, layer, config, **kwargs):
"""
Directly return the corresponding mask

Parameters
----------
layer : LayerInfo
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information

Returns
-------
dict
Mask of the layer
"""
assert layer.name in self.masks
return self.masks[layer.name]
masks = torch.load(masks_file, map_location)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to wrap the whole function with with torch.no_grad()?

for name, module in model.named_modules():
if name in masks:
module.weight.data = module.weight.data.mul_(masks[name]['weight'])
if hasattr(module, 'bias') and module.bias is not None and 'bias' in masks[name]:
module.bias.data = module.bias.data.mul_(masks[name]['bias'])