Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Speedup supports channel pruning (#2906)
Browse files Browse the repository at this point in the history
  • Loading branch information
chicm-ms authored Oct 9, 2020
1 parent f43719a commit 8bc74a2
Show file tree
Hide file tree
Showing 9 changed files with 639 additions and 144 deletions.
32 changes: 32 additions & 0 deletions src/sdk/pynni/nni/_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,36 @@ def _extract_cat_info(self, node_group, cpp_node):
cat_info['in_shape'] = input_shapes
return cat_info

def _extract_linear_shape_info(self, node_group):
"""
Extract linear shape input/output tensor shape info from its aten::addmm op.
Parameters
----------
node_group : NodePyGroup
NodePyGroup object associated with the linear module.
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
for cpp_node in node_group.node_cpps:
if cpp_node.kind() == 'aten::addmm':
# https://github.com/pytorch/pytorch/blob/1.6/torch/nn/functional.py#L1682
# inputs of aten::addmm:
# inputs[0] is bias
# inputs[1] is input data
# inputs[2] is weight
t_input = list(cpp_node.inputs())[1]
t_output = cpp_node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
return None

def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
Expand Down Expand Up @@ -701,6 +731,8 @@ def _extract_auxiliary_info(self):
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
node_group.auxiliary = self._extract_shape_info(cpp_node)
elif node_group.op_type == 'Linear':
node_group.auxiliary = self._extract_linear_shape_info(node_group)
elif node_group.op_type == CAT_KIND:
# get the detail information for cat func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
Expand Down
12 changes: 8 additions & 4 deletions src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import logging
from schema import And, Optional
from schema import And, Optional, SchemaError
from nni._graph_utils import TorchModuleGraph
from nni.compression.torch.utils.shape_dependency import ChannelDependency, GroupDependency
from .constants import MASKER_DICT
Expand Down Expand Up @@ -186,12 +186,16 @@ def update_mask(self):

def validate_config(self, model, config_list):
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
for config in config_list:
if 'exclude' not in config and 'sparsity' not in config:
raise SchemaError('Either sparisty or exclude must be specified!')

def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
"""
Expand Down
19 changes: 13 additions & 6 deletions src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,19 @@ def replace_conv2d(conv, mask):
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size()[0]

_logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
groups = conv.groups
if conv.in_channels == conv.out_channels == conv.groups:
# remove groups for depthwise layers
assert in_channels == out_channels
groups = in_channels
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
groups=groups,
bias=conv.bias is not None,
padding_mode=conv.padding_mode)

Expand All @@ -142,13 +146,16 @@ def replace_conv2d(conv, mask):
# channal is also divided into serveral groups and each group
# filter may have different input channel indexes.
input_step = int(conv.in_channels / conv.groups)
in_channels_group = int(in_channels / conv.groups)
filter_step = int(out_channels / conv.groups)
if mask.input_mask is not None:
in_channels_group = int(in_channels / groups)
filter_step = int(out_channels / groups)
if mask.input_mask is not None and not (in_channels == out_channels == groups):
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
if not current_input_index:
# there is no kept channel in current group
continue
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
# if the groups is larger than 1, the input channels of each
Expand Down
46 changes: 18 additions & 28 deletions src/sdk/pynni/nni/compression/torch/speedup/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,13 @@
import logging
import torch
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from nni.compression.torch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim

_logger = logging.getLogger(__name__)


def get_module_by_name(model, module_name):
"""
Get a module specified by its module name
Parameters
----------
model : pytorch model
the pytorch model from which to get its module
module_name : str
the name of the required module
Returns
-------
module, module
the parent module of the required module, the required module
"""
name_list = module_name.split(".")
for name in name_list[:-1]:
model = getattr(model, name)
leaf_module = getattr(model, name_list[-1])
return model, leaf_module

class ModelSpeedup:
"""
This class is to speedup the model with provided weight mask
Expand Down Expand Up @@ -87,7 +66,8 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
if module_name in self.inferred_masks:
module_masks = self.inferred_masks[module_name]
else:
module_masks = ModuleMasks(module_name)
_, m = get_module_by_name(self.bound_model, module_name)
module_masks = ModuleMasks(module_name, m)
self.inferred_masks[module_name] = module_masks

m_type = self.torch_graph.name_to_node[module_name].op_type
Expand All @@ -98,7 +78,12 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
raise RuntimeError(
"Has not supported infering input/output shape from mask for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if m_type in ['Linear']:
input_cmask, output_cmask = infer_from_mask[m_type](
module_masks, mask, self.torch_graph.name_to_node[module_name].auxiliary
)
else:
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if in_shape is not None:
_logger.debug("in_shape is not None")
if not m_type in infer_from_inshape:
Expand All @@ -124,7 +109,10 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
raise RuntimeError(
"Has not supported infering input shape from output shape for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape, self.torch_graph.name_to_node[module_name].auxiliary)
else:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)

if input_cmask:
predecessors = self.torch_graph.find_predecessors(module_name)
Expand Down Expand Up @@ -178,7 +166,6 @@ def replace_compressed_modules(self):
else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type))


def speedup_model(self):
"""
There are basically two steps:
Expand All @@ -187,8 +174,11 @@ def speedup_model(self):
"""
training = self.bound_model.training
_logger.info("start to speed up the model")

_logger.info("fix the mask conflict of the interdependent layers")
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
_, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
set_conv_prune_dim(conv_prune_dim)

_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info("replace compressed modules...")
Expand Down
Loading

0 comments on commit 8bc74a2

Please sign in to comment.