This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Speedup enhancement #2719

Merged: 8 commits, Aug 4, 2020
8 changes: 8 additions & 0 deletions src/sdk/pynni/nni/compression/torch/speedup/compressor.py
@@ -141,6 +141,14 @@ def infer_modules_masks(self):
"""
for module_name, mask in self.masks.items():
_logger.debug('Start mask inference from %s', module_name)
if module_name not in self.torch_graph.name_to_node:
# this module is not traced in the torch_graph;
# jit.trace only correctly records functions and
# modules that are not data dependent (i.e., that do
# not have conditionals on data in tensors),
# so if a node is not traced we simply skip it.
_logger.warning('%s has a mask but is not found in the traced graph, so it is skipped.', module_name)
continue
self.infer_module_mask(module_name, None, mask=mask)

def replace_compressed_modules(self):
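Note: the skip added above relies on a property of jit.trace worth spelling out: tracing records only the operations executed for the given example input, so modules with data-dependent control flow are not captured faithfully and may be missing from the traced graph. A minimal sketch of such a module (the DataDependent class below is invented for this note, not part of NNI):

import torch
import torch.nn as nn

class DataDependent(nn.Module):
    """Hypothetical module whose forward path depends on tensor values."""
    def forward(self, x):
        # jit.trace only records the branch taken for the example input;
        # the untaken branch is dropped from the traced graph (with a
        # TracerWarning), which is why such nodes can be absent from
        # torch_graph.name_to_node and are skipped above.
        if x.sum() > 0:
            return torch.relu(x)
        return -x

traced = torch.jit.trace(DataDependent(), torch.randn(1, 3))
print(traced.graph)  # contains only the path taken during tracing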
27 changes: 18 additions & 9 deletions src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
@@ -222,6 +222,7 @@ def __repr__(self):
'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
@@ -241,7 +242,8 @@ def __repr__(self):
'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask)
'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask)
}
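The two new aten-level entries ('aten::relu_' and 'aten::dropout') matter because models that call the functional or in-place forms never instantiate an nn.ReLU or nn.Dropout module, so only the aten op kinds appear in the traced graph. A small hypothetical model that would exercise exactly these entries (invented for illustration; the mappings F.relu(..., inplace=True) -> aten::relu_ and F.dropout -> aten::dropout are standard PyTorch tracing behavior):

import torch.nn as nn
import torch.nn.functional as F

class FunctionalOps(nn.Module):
    """Hypothetical model using functional/in-place ops instead of modules."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)

    def forward(self, x):
        x = F.relu(self.conv(x), inplace=True)            # traced as aten::relu_
        x = F.dropout(x, p=0.5, training=self.training)   # traced as aten::dropout
        return x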

"""
@@ -258,8 +260,14 @@ def dropout_inshape(module_masks, mask):
return module_masks.output_mask
# if already visited
assert module_masks.input_mask <= mask
if module_masks.input_mask == mask:
return None
# They should be the same: masks are passed by reference (not by value),
# so `mask` and `module_masks.input_mask` are actually two references to
# the same object. We therefore need to keep passing the mask on to the
# following nodes even when module_masks.input_mask == mask.
# If the mask were passed via copy.deepcopy(), we could stop when
# module_masks.input_mask == mask.
# if module_masks.input_mask == mask:
# return None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
return module_masks.output_mask
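The reference-versus-copy point made in the comment above is plain Python semantics; a toy sketch (the Mask class is a stand-in for CoarseMask, invented for this note):

import copy

class Mask:
    """Toy stand-in for CoarseMask, only to show object identity."""
    def __init__(self, index):
        self.index = index

incoming = Mask([0, 2, 5])

# Passed by reference: the stored input mask and the incoming mask are the
# same object, so an equality check cannot distinguish "already propagated"
# from "seen again" -- the mask still has to be forwarded downstream.
stored = incoming
print(stored is incoming)        # True

# With copy.deepcopy() the stored mask would be independent, and the
# equality check could then serve as a stopping condition.
stored_copy = copy.deepcopy(incoming)
print(stored_copy is incoming)   # False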
@@ -413,7 +421,8 @@ def linear_inshape(module_masks, mask):
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[0] is None
assert module_masks.input_mask is None
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
module_masks.set_input_mask(mask)
return None

@@ -451,7 +460,10 @@ def view_inshape(module_masks, mask, shape):
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
assert module_masks.input_mask is None
# due to the cat operation, the same node may be
# accessed more than once
if module_masks.input_mask is not None:
assert module_masks.input_mask <= mask
module_masks.set_input_mask(mask)
output_cmask = CoarseMask(num_dim=2)
index = []
@@ -535,12 +547,9 @@ def relu_inshape(module_masks, mask):
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
# TODO: double check this assert, is it possible that a module is passed twice
if module_masks.input_mask is not None:
# check if has a mask conflict
assert module_masks.input_mask == mask
# No need to pass the mask again
return None
assert module_masks.input_mask <= mask
# assert module_masks.input_mask is None, "A relu op can only be processed once"
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
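The conflict checks above (module_masks.input_mask <= mask) treat masks as partially ordered. A rough sketch of what such a comparison could mean, assuming mask_index holds the kept indices per dimension; this illustrates the idea only and is not the actual CoarseMask implementation:

class ToyCoarseMask:
    """Toy sketch, not the real CoarseMask: per-dimension sets of kept indices."""
    def __init__(self, mask_index):
        # mask_index[d] is the set of kept indices on dimension d,
        # or None if nothing is pruned on that dimension.
        self.mask_index = mask_index

    def __le__(self, other):
        # self <= other if, on every dimension, self keeps a subset of
        # what other keeps (None meaning everything is kept).
        for mine, theirs in zip(self.mask_index, other.mask_index):
            if theirs is None:
                continue
            if mine is None or not mine.issubset(theirs):
                return False
        return True

a = ToyCoarseMask([None, {0, 2}, None, None])
b = ToyCoarseMask([None, {0, 1, 2}, None, None])
assert a <= b   # a keeps a subset of b's channels, so no conflict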
6 changes: 3 additions & 3 deletions src/sdk/pynni/tests/test_model_speedup.py
@@ -145,18 +145,18 @@ def test_speedup_bigmodel(self):
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2']:
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']:
Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False).to(device)
speedup_model = Model().to(device)
net.eval() # this line is necessary
speedup_model.eval()
# random generate the prune config for the pruner
cfgs = generate_random_sparsity(net)
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
speedup_model = Model().to(device)
speedup_model.eval()
state_dict = torch.load(MODEL_FILE)
speedup_model.load_state_dict(state_dict)
zero_bn_bias(net)
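For context, the integration test above follows the usual prune-then-speedup workflow. A rough end-to-end sketch under the NNI API of this era (file names, input shape, and the sparsity config are placeholders; treat the snippet as illustrative rather than canonical):

import torch
from torchvision import models
from nni.compression.torch import L1FilterPruner, ModelSpeedup

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_FILE, MASK_FILE = './pruned_model.pth', './mask.pth'

# 1. Prune a pretrained model and export the pruned weights plus masks.
net = models.resnet18(pretrained=True).to(device)
net.eval()
cfg = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]   # example config
pruner = L1FilterPruner(net, cfg)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()

# 2. Load the exported weights into a fresh model and run ModelSpeedup,
#    which uses the mask/shape inference patched in this PR to physically
#    shrink the pruned layers.
speedup_model = models.resnet18().to(device)
speedup_model.eval()
speedup_model.load_state_dict(torch.load(MODEL_FILE))
dummy_input = torch.randn(1, 3, 224, 224).to(device)
ModelSpeedup(speedup_model, dummy_input, MASK_FILE).speedup_model()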