Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix bug for speedup module and enhance the Ut for speedup #3279

Merged
merged 5 commits into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions nni/compression/pytorch/speedup/infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,23 +891,18 @@ def convert_to_coarse_mask(mask, dim=0):
sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
index = torch.nonzero(weight_mask.abs().sum(
sum_idx) != 0, as_tuple=True)[0]
if len(index) == weight_mask.shape[dim]: # full mask
index = None
J-shang marked this conversation as resolved.
Show resolved Hide resolved

if index is None:
return None, None, None
zheng-ningxin marked this conversation as resolved.
Show resolved Hide resolved
else:
index = index.long().to(weight_mask.device)
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=dim, index=index)
bias_cmask = None
if dim == 0 and 'bias' in mask and mask['bias'] is not None:
bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
assert torch.all(torch.eq(index, bias_index)), \
"bias mask should be consistent with weight mask"
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=bias_index)
return index, weight_cmask, bias_cmask
index = index.long().to(weight_mask.device)
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=dim, index=index)
bias_cmask = None
if dim == 0 and 'bias' in mask and mask['bias'] is not None:
bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
assert torch.all(torch.eq(index, bias_index)), \
"bias mask should be consistent with weight mask"
bias_cmask = CoarseMask(num_dim=1)
bias_cmask.add_index_mask(dim=0, index=bias_index)
return index, weight_cmask, bias_cmask

index, weight_cmask, bias_cmask = convert_to_coarse_mask(
mask, dim=conv_prune_dim)
Expand Down Expand Up @@ -962,6 +957,7 @@ def conv2d_inshape(module_masks, mask):
# the same conv layer may be accessed more
# than once, such as a concat operation.
# mask conflict should be solved by fix_mask_conflict before speedup

assert module_masks.input_mask == mask

# shape changes pass through depths wise conv layers
Expand Down
9 changes: 8 additions & 1 deletion nni/compression/pytorch/utils/mask_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# if the input is the path of the mask_file
assert os.path.exists(masks)
masks = torch.load(masks)
assert len(masks) > 0, 'Mask tensor cannot be empty'
# if the user uses the model and dummy_input to trace the model, we
# should get the traced model handly, so that, we only trace the
# model once, GroupMaskConflict and ChannelMaskConflict will reuse
Expand Down Expand Up @@ -127,6 +128,7 @@ def fix_mask(self):
for layer in layers:
if layer in self.masks:
continue

module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
Expand All @@ -136,6 +138,7 @@ def fix_mask(self):
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight': w_mask, 'bias': b_mask}

return self.masks


Expand Down Expand Up @@ -250,6 +253,10 @@ def fix_mask(self):
self.model, self.dummy_input, self.traced)
depen_sets = channel_depen.dependency_sets
sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)

(_tmp_name, _tmp_tensor) = list(self.masks.items())[0]
device = _tmp_tensor['weight'].device

for dset in depen_sets:
if len(dset) <= 1:
continue
Expand Down Expand Up @@ -301,7 +308,7 @@ def fix_mask(self):

for i, dim_mask in enumerate(channel_masks):
if dim_mask is None:
channel_masks[i] = torch.ones(num_channels).int()
channel_masks[i] = torch.ones(num_channels).int().to(device)

# merge masks with 'or'
merged_channel_mask = channel_masks[0].clone()
Expand Down
113 changes: 69 additions & 44 deletions test/ut/sdk/test_model_speedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT license.

import os
import psutil
import sys
import numpy as np
import torch
Expand Down Expand Up @@ -128,6 +129,18 @@ def generate_random_sparsity(model):
'sparsity': sparsity})
return cfg_list

def generate_random_sparsity_v2(model):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe better to use ratio as a parameter, like def generate_random_sparsity(model, layer_ratio):, just personal opinion, the current implementation is fine to me.

"""
Only select 50% layers to prune.
"""
cfg_list = []
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
if np.random.uniform(0, 1.0) > 0.5:
sparsity = np.random.uniform(0.5, 0.99)
cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
'sparsity': sparsity})
return cfg_list

def zero_bn_bias(model):
with torch.no_grad():
Expand Down Expand Up @@ -292,52 +305,62 @@ def test_convtranspose_model(self):
# Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282

def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1',
'mobilenet_v2', 'densenet121',
# skip this test on windows(7GB mem available) due to memory limit
# Note: hack trick, may be updated in the future
if 'win' in sys.platform or 'Win'in sys.platform:
print('Skip test_speedup_integration on windows due to memory limit!')
return

Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]

for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121' , 'densenet169',
# 'inception_v3' inception is too large and may fail the pipeline
'densenet169', 'resnet50']:
kwargs = {
'pretrained': True
}
if model_name == 'resnet50':
# testing multiple groups
'resnet50']:

for gen_cfg_func in Gen_cfg_funcs:
kwargs = {
'pretrained': False,
'groups': 4
'pretrained': True
}
if model_name == 'resnet50':
# testing multiple groups
kwargs = {
'pretrained': False,
'groups': 4
}
Model = getattr(models, model_name)
net = Model(**kwargs).to(device)
speedup_model = Model(**kwargs).to(device)
net.eval() # this line is necessary
speedup_model.eval()
# random generate the prune config for the pruner
cfgs = gen_cfg_func(net)
print("Testing {} with compression config \n {}".format(model_name, cfgs))
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
state_dict = torch.load(MODEL_FILE)
speedup_model.load_state_dict(state_dict)
zero_bn_bias(net)
zero_bn_bias(speedup_model)

data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(speedup_model, data, MASK_FILE)
ms.speedup_model()

speedup_model.eval()

ori_out = net(data)
speeded_out = speedup_model(data)
ori_sum = torch.sum(ori_out).item()
speeded_sum = torch.sum(speeded_out).item()
print('Sum of the output of %s (before speedup):' %
model_name, ori_sum)
print('Sum of the output of %s (after speedup):' %
model_name, speeded_sum)
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

Model = getattr(models, model_name)
net = Model(**kwargs).to(device)
speedup_model = Model(**kwargs).to(device)
net.eval() # this line is necessary
speedup_model.eval()
# random generate the prune config for the pruner
cfgs = generate_random_sparsity(net)
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
state_dict = torch.load(MODEL_FILE)
speedup_model.load_state_dict(state_dict)
zero_bn_bias(net)
zero_bn_bias(speedup_model)

data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(speedup_model, data, MASK_FILE)
ms.speedup_model()

speedup_model.eval()

ori_out = net(data)
speeded_out = speedup_model(data)
ori_sum = torch.sum(ori_out).item()
speeded_sum = torch.sum(speeded_out).item()
print('Sum of the output of %s (before speedup):' %
model_name, ori_sum)
print('Sum of the output of %s (after speedup):' %
model_name, speeded_sum)
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

def test_channel_prune(self):
orig_net = resnet18(num_classes=10).to(device)
Expand Down Expand Up @@ -369,8 +392,10 @@ def test_channel_prune(self):
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

def tearDown(self):
os.remove(MODEL_FILE)
os.remove(MASK_FILE)
if os.path.exists(MODEL_FILE):
os.remove(MODEL_FILE)
if os.path.exists(MASK_FILE):
os.remove(MASK_FILE)


if __name__ == '__main__':
Expand Down