Skip to content

Commit

Permalink
Merge pull request #118 from microsoft/master
Browse files Browse the repository at this point in the history
pull
  • Loading branch information
chicm-ms authored Oct 9, 2020
2 parents 3c5cef2 + 8bc74a2 commit 85a879e
Show file tree
Hide file tree
Showing 30 changed files with 1,034 additions and 159 deletions.
1 change: 1 addition & 0 deletions deployment/pypi/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
'ruamel.yaml',
'psutil',
'requests',
'responses',
'astor',
'PythonWebHDFS',
'hyperopt==0.1.2',
Expand Down
2 changes: 1 addition & 1 deletion examples/trials/ga_squad/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
tensorflow==1.15.2
tensorflow==1.15.4
2 changes: 1 addition & 1 deletion examples/trials/network_morphism/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy==1.14.2
tensorflow==1.15.2
tensorflow==1.15.4
torchvision==0.2.1
Keras==2.3.1
torch==0.4.1
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def read(fname):
'psutil',
'ruamel.yaml',
'requests',
'responses',
'scipy',
'schema',
'PythonWebHDFS',
Expand Down
4 changes: 3 additions & 1 deletion src/nni_manager/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
"hoek": "^4.2.1",
"js-yaml": "^3.13.1",
"npm": "^6.13.4",
"acorn": ">=7.1.1"
"acorn": ">=7.1.1",
"node-forge": "^0.10.0",
"dot-prop": "^4.2.1"
},
"engines": {
"node": ">=10.0.0"
Expand Down
14 changes: 8 additions & 6 deletions src/nni_manager/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1427,9 +1427,10 @@ doctrine@^3.0.0:
dependencies:
esutils "^2.0.2"

dot-prop@^4.1.0:
version "4.2.0"
resolved "https://registry.yarnpkg.com/dot-prop/-/dot-prop-4.2.0.tgz#1f19e0c2e1aa0e32797c49799f2837ac6af69c57"
dot-prop@^4.1.0, dot-prop@^4.2.1:
version "4.2.1"
resolved "https://registry.yarnpkg.com/dot-prop/-/dot-prop-4.2.1.tgz#45884194a71fc2cda71cbb4bceb3a4dd2f433ba4"
integrity sha512-l0p4+mIuJIua0mhxGoh4a+iNL9bmeK5DvnSVQa6T0OhrVmaEa1XScX5Etc673FePCJOArq/4Pa2cLGODUWTPOQ==
dependencies:
is-obj "^1.0.0"

Expand Down Expand Up @@ -3349,9 +3350,10 @@ node-fetch-npm@^2.0.2:
json-parse-better-errors "^1.0.0"
safe-buffer "^5.1.1"

node-forge@^0.7.6:
version "0.7.6"
resolved "https://registry.yarnpkg.com/node-forge/-/node-forge-0.7.6.tgz#fdf3b418aee1f94f0ef642cd63486c77ca9724ac"
node-forge@^0.10.0, node-forge@^0.7.6:
version "0.10.0"
resolved "https://registry.yarnpkg.com/node-forge/-/node-forge-0.10.0.tgz#32dea2afb3e9926f02ee5ce8794902691a676bf3"
integrity sha512-PPmu8eEeG9saEUvI97fm4OYxXVB6bFvyNTyiUOBichBpFG8A1Ljw3bY62+5oOjDEMHRnd0Y7HQ+x7uzxOzC6JA==

node-gyp@^5.0.2, node-gyp@^5.0.5:
version "5.0.7"
Expand Down
32 changes: 32 additions & 0 deletions src/sdk/pynni/nni/_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,36 @@ def _extract_cat_info(self, node_group, cpp_node):
cat_info['in_shape'] = input_shapes
return cat_info

def _extract_linear_shape_info(self, node_group):
"""
Extract linear shape input/output tensor shape info from its aten::addmm op.
Parameters
----------
node_group : NodePyGroup
NodePyGroup object associated with the linear module.
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
for cpp_node in node_group.node_cpps:
if cpp_node.kind() == 'aten::addmm':
# https://github.com/pytorch/pytorch/blob/1.6/torch/nn/functional.py#L1682
# inputs of aten::addmm:
# inputs[0] is bias
# inputs[1] is input data
# inputs[2] is weight
t_input = list(cpp_node.inputs())[1]
t_output = cpp_node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
return None

def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
Expand Down Expand Up @@ -701,6 +731,8 @@ def _extract_auxiliary_info(self):
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
node_group.auxiliary = self._extract_shape_info(cpp_node)
elif node_group.op_type == 'Linear':
node_group.auxiliary = self._extract_linear_shape_info(node_group)
elif node_group.op_type == CAT_KIND:
# get the detail information for cat func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
Expand Down
12 changes: 8 additions & 4 deletions src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import logging
from schema import And, Optional
from schema import And, Optional, SchemaError
from nni._graph_utils import TorchModuleGraph
from nni.compression.torch.utils.shape_dependency import ChannelDependency, GroupDependency
from .constants import MASKER_DICT
Expand Down Expand Up @@ -186,12 +186,16 @@ def update_mask(self):

def validate_config(self, model, config_list):
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
for config in config_list:
if 'exclude' not in config and 'sparsity' not in config:
raise SchemaError('Either sparisty or exclude must be specified!')

def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
"""
Expand Down
19 changes: 13 additions & 6 deletions src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,19 @@ def replace_conv2d(conv, mask):
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size()[0]

_logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
groups = conv.groups
if conv.in_channels == conv.out_channels == conv.groups:
# remove groups for depthwise layers
assert in_channels == out_channels
groups = in_channels
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
groups=groups,
bias=conv.bias is not None,
padding_mode=conv.padding_mode)

Expand All @@ -142,13 +146,16 @@ def replace_conv2d(conv, mask):
# channal is also divided into serveral groups and each group
# filter may have different input channel indexes.
input_step = int(conv.in_channels / conv.groups)
in_channels_group = int(in_channels / conv.groups)
filter_step = int(out_channels / conv.groups)
if mask.input_mask is not None:
in_channels_group = int(in_channels / groups)
filter_step = int(out_channels / groups)
if mask.input_mask is not None and not (in_channels == out_channels == groups):
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
if not current_input_index:
# there is no kept channel in current group
continue
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
# if the groups is larger than 1, the input channels of each
Expand Down
46 changes: 18 additions & 28 deletions src/sdk/pynni/nni/compression/torch/speedup/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,13 @@
import logging
import torch
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from nni.compression.torch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim

_logger = logging.getLogger(__name__)


def get_module_by_name(model, module_name):
"""
Get a module specified by its module name
Parameters
----------
model : pytorch model
the pytorch model from which to get its module
module_name : str
the name of the required module
Returns
-------
module, module
the parent module of the required module, the required module
"""
name_list = module_name.split(".")
for name in name_list[:-1]:
model = getattr(model, name)
leaf_module = getattr(model, name_list[-1])
return model, leaf_module

class ModelSpeedup:
"""
This class is to speedup the model with provided weight mask
Expand Down Expand Up @@ -87,7 +66,8 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
if module_name in self.inferred_masks:
module_masks = self.inferred_masks[module_name]
else:
module_masks = ModuleMasks(module_name)
_, m = get_module_by_name(self.bound_model, module_name)
module_masks = ModuleMasks(module_name, m)
self.inferred_masks[module_name] = module_masks

m_type = self.torch_graph.name_to_node[module_name].op_type
Expand All @@ -98,7 +78,12 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
raise RuntimeError(
"Has not supported infering input/output shape from mask for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if m_type in ['Linear']:
input_cmask, output_cmask = infer_from_mask[m_type](
module_masks, mask, self.torch_graph.name_to_node[module_name].auxiliary
)
else:
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if in_shape is not None:
_logger.debug("in_shape is not None")
if not m_type in infer_from_inshape:
Expand All @@ -124,7 +109,10 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
raise RuntimeError(
"Has not supported infering input shape from output shape for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape, self.torch_graph.name_to_node[module_name].auxiliary)
else:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)

if input_cmask:
predecessors = self.torch_graph.find_predecessors(module_name)
Expand Down Expand Up @@ -178,7 +166,6 @@ def replace_compressed_modules(self):
else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type))


def speedup_model(self):
"""
There are basically two steps:
Expand All @@ -187,8 +174,11 @@ def speedup_model(self):
"""
training = self.bound_model.training
_logger.info("start to speed up the model")

_logger.info("fix the mask conflict of the interdependent layers")
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
_, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
set_conv_prune_dim(conv_prune_dim)

_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info("replace compressed modules...")
Expand Down
Loading

0 comments on commit 85a879e

Please sign in to comment.