This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

support prelu to speedup model #3842

Merged
merged 1 commit on Jul 9, 2021
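
This PR teaches the speedup pass to handle PReLU: compress_modules.py gains a replacement rule and infer_shape.py gains forward/backward shape inference for it. A minimal sketch of how the change is exercised, assuming the ModelSpeedup entry point documented for NNI and a mask file exported by one of its pruners (the toy model, input shape, and 'masks.pth' path are illustrative):

```python
import torch
from nni.compression.pytorch import ModelSpeedup

# Toy model with a PReLU activation; before this change the speedup
# pass had no replacement or shape-inference rules for PReLU.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.PReLU(8),
    torch.nn.Conv2d(8, 4, 3),
)
dummy_input = torch.randn(1, 3, 32, 32)

# 'masks.pth' stands in for masks exported by an NNI pruner.
ModelSpeedup(model, dummy_input, 'masks.pth').speedup_model()
```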
26 changes: 26 additions & 0 deletions nni/compression/pytorch/speedup/compress_modules.py
@@ -15,6 +15,7 @@
    'AvgPool2d': lambda module, mask: no_replace(module, mask),
    'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
    'ReLU': lambda module, mask: no_replace(module, mask),
    'PReLU': lambda module, mask: replace_prelu(module, mask),
    'ReLU6': lambda module, mask: no_replace(module, mask),
    'Sigmoid': lambda module, mask: no_replace(module, mask),
    'Linear': lambda module, mask: replace_linear(module, mask),
@@ -31,6 +32,31 @@ def no_replace(module, mask):
    _logger.debug("no need to replace")
    return module

def replace_prelu(prelu, mask):
    """
    Parameters
    ----------
    prelu : torch.nn.PReLU
        The prelu module to be replaced
    mask : ModuleMasks
        The masks of this module

    Returns
    -------
    torch.nn.PReLU
        The new prelu module
    """
    assert isinstance(mask, ModuleMasks)
    assert 'weight' in mask.param_masks
    index = mask.param_masks['weight'].mask_index[0]
    num_features = index.size()[0]
    # _logger.debug("replace prelu with num_features: %d", num_features)
    if num_features == 0:
        return torch.nn.Identity()
    new_prelu = torch.nn.PReLU(num_features)
    # copy the slopes of the kept channels into the new module
    new_prelu.weight.data = torch.index_select(prelu.weight.data, 0, index)
    return new_prelu
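
A torch-only sketch of what replace_prelu does (the channel index below is hypothetical, standing in for mask.param_masks['weight'].mask_index[0]): trimming the slope tensor with index_select preserves the module's behavior on the surviving channels:

```python
import torch

old = torch.nn.PReLU(num_parameters=4)
index = torch.tensor([0, 2])  # hypothetical kept channels

new = torch.nn.PReLU(index.size(0))
new.weight.data = torch.index_select(old.weight.data, 0, index)

# The compact module on the kept channels matches the original.
x = torch.randn(1, 4, 8, 8)
assert torch.allclose(new(torch.index_select(x, 1, index)),
                      torch.index_select(old(x), 1, index))
```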

def replace_linear(linear, mask):
"""
58 changes: 58 additions & 0 deletions nni/compression/pytorch/speedup/infer_shape.py
@@ -240,6 +240,7 @@ def __repr__(self):
infer_from_inshape = {
    'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'PReLU': lambda module_masks, mask: prelu_inshape(module_masks, mask),
    'Sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
@@ -293,6 +294,7 @@ def __repr__(self):
    'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),

    'ReLU': lambda module_masks, mask: relu_outshape(module_masks, mask),
    'PReLU': lambda module_masks, mask: prelu_outshape(module_masks, mask),
    'ReLU6': lambda module_masks, mask: relu_outshape(module_masks, mask),
    'aten::relu': lambda module_masks, mask: relu_outshape(module_masks, mask),
    'aten::tanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
@@ -735,6 +737,62 @@ def maxpool2d_outshape(module_masks, mask):
    module_masks.set_output_mask(mask)
    return mask

def prelu_inshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the PReLU
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    # the PReLU weight is 1-D (one slope per channel), so the channel
    # mask of the input indexes it on dim 0
    weight_cmask = CoarseMask(num_dim=1)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    return mask
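
A short sketch of the convention these asserts encode, assuming CoarseMask is importable from nni.compression.pytorch.speedup.infer_shape and using an illustrative channel index: the NCHW input is masked only on dim 1, and the same index becomes the 1-D weight mask:

```python
import torch
from nni.compression.pytorch.speedup.infer_shape import CoarseMask

in_mask = CoarseMask(num_dim=4)  # N, C, H, W
in_mask.add_index_mask(dim=1, index=torch.tensor([0, 2, 5]))

# PReLU keeps one slope per channel, so its weight mask is the same
# channel index applied on dim 0 of a 1-D mask.
weight_cmask = CoarseMask(num_dim=1)
weight_cmask.add_index_mask(dim=0, index=in_mask.mask_index[1])
```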

def prelu_outshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the PReLU
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None

    # the PReLU weight is 1-D, matching prelu_inshape above, so the
    # channel mask of the output indexes it on dim 0
    weight_cmask = CoarseMask(num_dim=1)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)

    return mask
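
For context on why PReLU needs these rules while ReLU only needs no_replace and relu_inshape: ReLU is parameter-free, but a per-channel PReLU hard-fails once its input loses channels, as this torch-only snippet shows:

```python
import torch

prelu = torch.nn.PReLU(num_parameters=4)
pruned = torch.randn(1, 2, 8, 8)  # two of four channels pruned away
try:
    prelu(pruned)
except RuntimeError as err:
    # PyTorch rejects the channel mismatch, hence the replacement logic above.
    print(err)
```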


def relu_inshape(module_masks, mask):
"""