This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Speedup supports channel pruning #2906

Merged
merged 119 commits into master from speedup-channel-pruning
Oct 9, 2020
119 commits
3a45961
Merge pull request #31 from microsoft/master
chicm-ms Aug 6, 2019
633db43
Merge pull request #32 from microsoft/master
chicm-ms Sep 9, 2019
3e926f1
Merge pull request #33 from microsoft/master
chicm-ms Oct 8, 2019
f173789
Merge pull request #34 from microsoft/master
chicm-ms Oct 9, 2019
508850a
Merge pull request #35 from microsoft/master
chicm-ms Oct 9, 2019
5a0e9c9
Merge pull request #36 from microsoft/master
chicm-ms Oct 10, 2019
e7df061
Merge pull request #37 from microsoft/master
chicm-ms Oct 23, 2019
2175cef
Merge pull request #38 from microsoft/master
chicm-ms Oct 29, 2019
2ccbfbb
Merge pull request #39 from microsoft/master
chicm-ms Oct 30, 2019
b29cb0b
Merge pull request #40 from microsoft/master
chicm-ms Oct 30, 2019
4a3ba83
Merge pull request #41 from microsoft/master
chicm-ms Nov 4, 2019
c8a1148
Merge pull request #42 from microsoft/master
chicm-ms Nov 4, 2019
73c6101
Merge pull request #43 from microsoft/master
chicm-ms Nov 5, 2019
6a518a9
Merge pull request #44 from microsoft/master
chicm-ms Nov 11, 2019
a0d587f
Merge pull request #45 from microsoft/master
chicm-ms Nov 12, 2019
e905bfe
Merge pull request #46 from microsoft/master
chicm-ms Nov 14, 2019
4b266f3
Merge pull request #47 from microsoft/master
chicm-ms Nov 15, 2019
237ff4b
Merge pull request #48 from microsoft/master
chicm-ms Nov 21, 2019
682be01
Merge pull request #49 from microsoft/master
chicm-ms Nov 25, 2019
133af82
Merge pull request #50 from microsoft/master
chicm-ms Nov 25, 2019
71a8a25
Merge pull request #51 from microsoft/master
chicm-ms Nov 26, 2019
d2a73bc
Merge pull request #52 from microsoft/master
chicm-ms Nov 26, 2019
198cf5e
Merge pull request #53 from microsoft/master
chicm-ms Dec 5, 2019
cdbfaf9
Merge pull request #54 from microsoft/master
chicm-ms Dec 6, 2019
7e9b29e
Merge pull request #55 from microsoft/master
chicm-ms Dec 10, 2019
d00c46d
Merge pull request #56 from microsoft/master
chicm-ms Dec 10, 2019
de7d1fa
Merge pull request #57 from microsoft/master
chicm-ms Dec 11, 2019
1835ab0
Merge pull request #58 from microsoft/master
chicm-ms Dec 12, 2019
24fead6
Merge pull request #59 from microsoft/master
chicm-ms Dec 20, 2019
0b7321e
Merge pull request #60 from microsoft/master
chicm-ms Dec 23, 2019
60058d4
Merge pull request #61 from microsoft/master
chicm-ms Dec 23, 2019
b111a55
Merge pull request #62 from microsoft/master
chicm-ms Dec 24, 2019
611c337
Merge pull request #63 from microsoft/master
chicm-ms Dec 30, 2019
4a1f14a
Merge pull request #64 from microsoft/master
chicm-ms Jan 10, 2020
7a9e604
Merge pull request #65 from microsoft/master
chicm-ms Jan 14, 2020
b8035b0
Merge pull request #66 from microsoft/master
chicm-ms Feb 4, 2020
47567d3
Merge pull request #67 from microsoft/master
chicm-ms Feb 10, 2020
614d427
Merge pull request #68 from microsoft/master
chicm-ms Feb 10, 2020
a0d9ed6
Merge pull request #69 from microsoft/master
chicm-ms Feb 11, 2020
22dc1ad
Merge pull request #70 from microsoft/master
chicm-ms Feb 19, 2020
0856813
Merge pull request #71 from microsoft/master
chicm-ms Feb 22, 2020
9e97bed
Merge pull request #72 from microsoft/master
chicm-ms Feb 25, 2020
16a1b27
Merge pull request #73 from microsoft/master
chicm-ms Mar 3, 2020
e246633
Merge pull request #74 from microsoft/master
chicm-ms Mar 4, 2020
0439bc1
Merge pull request #75 from microsoft/master
chicm-ms Mar 17, 2020
8b5613a
Merge pull request #76 from microsoft/master
chicm-ms Mar 18, 2020
43e8d31
Merge pull request #77 from microsoft/master
chicm-ms Mar 22, 2020
aae448e
Merge pull request #78 from microsoft/master
chicm-ms Mar 25, 2020
7095716
Merge pull request #79 from microsoft/master
chicm-ms Mar 25, 2020
c51263a
Merge pull request #80 from microsoft/master
chicm-ms Apr 11, 2020
9953c70
Merge pull request #81 from microsoft/master
chicm-ms Apr 14, 2020
f9136c4
Merge pull request #82 from microsoft/master
chicm-ms Apr 16, 2020
b384ad2
Merge pull request #83 from microsoft/master
chicm-ms Apr 20, 2020
ff592dd
Merge pull request #84 from microsoft/master
chicm-ms May 12, 2020
0b5378f
Merge pull request #85 from microsoft/master
chicm-ms May 18, 2020
a53e0b0
Merge pull request #86 from microsoft/master
chicm-ms May 25, 2020
3ea0b89
Merge pull request #87 from microsoft/master
chicm-ms May 28, 2020
cf3fb20
Merge pull request #88 from microsoft/master
chicm-ms May 28, 2020
7f4cdcd
Merge pull request #89 from microsoft/master
chicm-ms Jun 4, 2020
574db2c
Merge pull request #90 from microsoft/master
chicm-ms Jun 15, 2020
32bedcc
Merge pull request #91 from microsoft/master
chicm-ms Jun 21, 2020
6155aa4
Merge pull request #92 from microsoft/master
chicm-ms Jun 22, 2020
8139c9c
Merge pull request #93 from microsoft/master
chicm-ms Jun 23, 2020
43419d7
Merge pull request #94 from microsoft/master
chicm-ms Jun 28, 2020
6b6ee55
Merge pull request #95 from microsoft/master
chicm-ms Jun 28, 2020
1b975e0
Merge pull request #96 from microsoft/master
chicm-ms Jun 28, 2020
c8f3c5d
Merge pull request #97 from microsoft/master
chicm-ms Jun 29, 2020
4c306f0
Merge pull request #98 from microsoft/master
chicm-ms Jun 30, 2020
64de4c2
Merge pull request #99 from microsoft/master
chicm-ms Jun 30, 2020
0e5d3ac
Merge pull request #100 from microsoft/master
chicm-ms Jul 1, 2020
4a52608
Merge pull request #101 from microsoft/master
chicm-ms Jul 3, 2020
208b1ee
Merge pull request #102 from microsoft/master
chicm-ms Jul 8, 2020
e7b1a2e
Merge pull request #103 from microsoft/master
chicm-ms Jul 10, 2020
57bcc85
Merge pull request #104 from microsoft/master
chicm-ms Jul 22, 2020
030f5ef
Merge pull request #105 from microsoft/master
chicm-ms Jul 29, 2020
058c8b7
Merge pull request #106 from microsoft/master
chicm-ms Aug 2, 2020
9abd8c8
Merge pull request #107 from microsoft/master
chicm-ms Aug 10, 2020
13c6623
Merge pull request #108 from microsoft/master
chicm-ms Aug 11, 2020
b50b41e
Merge pull request #109 from microsoft/master
chicm-ms Aug 12, 2020
78f1418
Merge pull request #110 from microsoft/master
chicm-ms Aug 13, 2020
74acc8b
Merge pull request #111 from microsoft/master
chicm-ms Aug 17, 2020
5bf416a
Merge pull request #112 from microsoft/master
chicm-ms Aug 24, 2020
4a207f9
Merge pull request #113 from microsoft/master
chicm-ms Sep 3, 2020
7be897b
Merge pull request #114 from microsoft/master
chicm-ms Sep 16, 2020
f974b2c
Merge pull request #115 from microsoft/master
chicm-ms Sep 17, 2020
4488ec1
Speedup support channel pruning
chicm-ms Sep 17, 2020
d29bfc9
Merge branch 'master' into speedup-channel-pruning
chicm-ms Sep 17, 2020
9f2464a
updates
chicm-ms Sep 17, 2020
5fe44b8
updates
chicm-ms Sep 17, 2020
b7854c6
updates
chicm-ms Sep 18, 2020
d144e6a
fix conv groups
chicm-ms Sep 18, 2020
0b7712d
auto detect prune dim
chicm-ms Sep 18, 2020
ff7822a
updates
chicm-ms Sep 18, 2020
f6cf13d
updates
chicm-ms Sep 21, 2020
0882238
updates
chicm-ms Sep 21, 2020
e6c5314
updates
chicm-ms Sep 21, 2020
0c2f59b
Merge pull request #116 from microsoft/master
chicm-ms Sep 21, 2020
3c9de75
Merge branch 'master' into speedup-channel-pruning
chicm-ms Sep 21, 2020
c5e14f9
updates
chicm-ms Sep 21, 2020
d0f6ce9
updates
chicm-ms Sep 22, 2020
2aa761a
test multiple groups
chicm-ms Sep 23, 2020
38b2628
updates
chicm-ms Sep 23, 2020
4be03fc
updates
chicm-ms Sep 23, 2020
36f3aad
updates
chicm-ms Sep 23, 2020
081b280
updates
chicm-ms Sep 24, 2020
72ad87c
updates
chicm-ms Sep 24, 2020
891feb2
updates
chicm-ms Sep 25, 2020
1ae1691
updates
chicm-ms Sep 25, 2020
fa5f92d
updates
chicm-ms Sep 25, 2020
3c5cef2
Merge pull request #117 from microsoft/master
chicm-ms Sep 25, 2020
978d3c1
updates
chicm-ms Sep 25, 2020
50cf8d4
add linear shape
chicm-ms Sep 27, 2020
e0c96d8
updates
chicm-ms Sep 27, 2020
986f9c7
updates
chicm-ms Sep 27, 2020
8c382dc
updates
chicm-ms Sep 27, 2020
3a50e03
updates
chicm-ms Sep 27, 2020
fe5602a
updates
chicm-ms Sep 28, 2020
1249e82
channel pruning correctness testing
chicm-ms Oct 9, 2020
76da0ea
updates
chicm-ms Oct 9, 2020
32 changes: 32 additions & 0 deletions src/sdk/pynni/nni/_graph_utils.py
@@ -426,6 +426,36 @@ def _extract_cat_info(self, node_group, cpp_node):
cat_info['in_shape'] = input_shapes
return cat_info

def _extract_linear_shape_info(self, node_group):
"""
Extract the input/output tensor shapes of a linear module from its aten::addmm op.

Parameters
----------
node_group : NodePyGroup
NodePyGroup object associated with the linear module.

Returns
-------
dict or None
The shapes of the input and output tensors, or None if no aten::addmm op is found
"""
for cpp_node in node_group.node_cpps:
if cpp_node.kind() == 'aten::addmm':
# https://github.com/pytorch/pytorch/blob/1.6/torch/nn/functional.py#L1682
# inputs of aten::addmm:
# inputs[0] is bias
# inputs[1] is input data
# inputs[2] is weight
t_input = list(cpp_node.inputs())[1]
t_output = cpp_node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
return None

def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
@@ -701,6 +731,8 @@ def _extract_auxiliary_info(self):
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
node_group.node_cpps))[0]
node_group.auxiliary = self._extract_shape_info(cpp_node)
elif node_group.op_type == 'Linear':
node_group.auxiliary = self._extract_linear_shape_info(node_group)
elif node_group.op_type == CAT_KIND:
# get the detail information for cat func
cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
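For context on the new helper, here is a minimal standalone sketch (assuming PyTorch ~1.6, per the functional.py link in the comment above, where a 2-D `F.linear` call with bias lowers to `torch.addmm(bias, input, weight.t())`) of reading the same shape info from a traced graph:

```python
import torch

# Hedged sketch: relies on tracing recording concrete tensor shapes, which holds
# for torch.jit.trace in the PyTorch 1.6 line this PR targets.
linear = torch.nn.Linear(in_features=8, out_features=4)
traced = torch.jit.trace(linear, torch.randn(2, 8))

for node in traced.graph.nodes():
    if node.kind() == 'aten::addmm':
        t_input = list(node.inputs())[1]   # inputs[0] is bias, inputs[2] is weight
        t_output = node.output()
        print(t_input.type().sizes())      # [2, 8] -- the linear's input shape
        print(t_output.type().sizes())     # [2, 4] -- the linear's output shape
```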
12 changes: 8 additions & 4 deletions src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import logging
from schema import And, Optional
from schema import And, Optional, SchemaError
from nni._graph_utils import TorchModuleGraph
from nni.compression.torch.utils.shape_dependency import ChannelDependency, GroupDependency
from .constants import MASKER_DICT
@@ -186,12 +186,16 @@ def update_mask(self):

def validate_config(self, model, config_list):
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
for config in config_list:
if 'exclude' not in config and 'sparsity' not in config:
raise SchemaError('Either sparsity or exclude must be specified!')

def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
"""
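With `sparsity` and `op_types` now optional, a config entry can exclude layers from pruning instead of assigning them a sparsity. A hedged sketch of a config list the relaxed schema accepts (the op name is illustrative, not from this PR):

```python
# Each entry must carry either 'sparsity' or 'exclude'; an entry with neither
# now fails validation with a SchemaError.
config_list = [
    {'sparsity': 0.5, 'op_types': ['Conv2d']},      # prune every Conv2d at 50% sparsity
    {'exclude': True, 'op_names': ['features.0']},  # hypothetical layer kept intact
]
```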
19 changes: 13 additions & 6 deletions src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
@@ -116,15 +116,19 @@ def replace_conv2d(conv, mask):
else:
out_channels_index = mask.output_mask.mask_index[1]
out_channels = out_channels_index.size()[0]

_logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
groups = conv.groups
if conv.in_channels == conv.out_channels == conv.groups:
# remove groups for depthwise layers
assert in_channels == out_channels
groups = in_channels
_logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
new_conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
groups=groups,
bias=conv.bias is not None,
padding_mode=conv.padding_mode)

@@ -142,13 +146,16 @@ def replace_conv2d(conv, mask):
channel is also divided into several groups and each group
# filter may have different input channel indexes.
input_step = int(conv.in_channels / conv.groups)
in_channels_group = int(in_channels / conv.groups)
filter_step = int(out_channels / conv.groups)
if mask.input_mask is not None:
in_channels_group = int(in_channels / groups)
filter_step = int(out_channels / groups)
if mask.input_mask is not None and not (in_channels == out_channels == groups):
for groupid in range(conv.groups):
start = groupid * input_step
end = (groupid + 1) * input_step
current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
if not current_input_index:
# there is no kept channel in current group
continue
# shift the global index into the group index
current_input_index = [x-start for x in current_input_index]
# if the groups is larger than 1, the input channels of each
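The depthwise special case above exists because `groups` has to shrink along with the channels. A minimal sketch of the shape arithmetic, with illustrative channel counts:

```python
import torch

c, kept = 8, 5                                        # original / surviving channels
old = torch.nn.Conv2d(c, c, kernel_size=3, groups=c)  # depthwise: weight is (8, 1, 3, 3)
keep_idx = torch.arange(kept)                         # stand-in for the mask's kept indices

# Keeping groups=c would make Conv2d(5, 5, groups=8) invalid, so the replacement
# resets groups to the pruned channel count.
new = torch.nn.Conv2d(kept, kept, kernel_size=3, groups=kept)
with torch.no_grad():
    new.weight.copy_(old.weight[keep_idx])            # each depthwise filter is (1, 3, 3)
    new.bias.copy_(old.bias[keep_idx])
```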
46 changes: 18 additions & 28 deletions src/sdk/pynni/nni/compression/torch/speedup/compressor.py
@@ -4,34 +4,13 @@
import logging
import torch
from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
from nni.compression.torch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim

_logger = logging.getLogger(__name__)


def get_module_by_name(model, module_name):
"""
Get a module specified by its module name

Parameters
----------
model : pytorch model
the pytorch model from which to get its module
module_name : str
the name of the required module

Returns
-------
module, module
the parent module of the required module, the required module
"""
name_list = module_name.split(".")
for name in name_list[:-1]:
model = getattr(model, name)
leaf_module = getattr(model, name_list[-1])
return model, leaf_module

class ModelSpeedup:
"""
This class is to speedup the model with provided weight mask
@@ -87,7 +66,8 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
if module_name in self.inferred_masks:
module_masks = self.inferred_masks[module_name]
else:
module_masks = ModuleMasks(module_name)
_, m = get_module_by_name(self.bound_model, module_name)
module_masks = ModuleMasks(module_name, m)
self.inferred_masks[module_name] = module_masks

m_type = self.torch_graph.name_to_node[module_name].op_type
@@ -98,7 +78,12 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
raise RuntimeError(
"Has not supported infering input/output shape from mask for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if m_type in ['Linear']:
input_cmask, output_cmask = infer_from_mask[m_type](
module_masks, mask, self.torch_graph.name_to_node[module_name].auxiliary
)
else:
input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
if in_shape is not None:
_logger.debug("in_shape is not None")
if m_type not in infer_from_inshape:
@@ -124,7 +109,10 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
raise RuntimeError(
"Has not supported infering input shape from output shape for module/function: `{}`, {}"
.format(m_type, module_name))
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape, self.torch_graph.name_to_node[module_name].auxiliary)
else:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape)

if input_cmask:
predecessors = self.torch_graph.find_predecessors(module_name)
@@ -178,7 +166,6 @@ def replace_compressed_modules(self):
else:
raise RuntimeError("Unsupported node type: {}".format(g_node.type))


def speedup_model(self):
"""
There are basically two steps:
@@ -187,8 +174,11 @@
"""
training = self.bound_model.training
_logger.info("start to speed up the model")

_logger.info("fix the mask conflict of the interdependent layers")
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
_, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
set_conv_prune_dim(conv_prune_dim)

_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info("replace compressed modules...")
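Put together, the flow this PR enables is: mask output channels, export the masks, then let `ModelSpeedup` (which now fixes mask conflicts, auto-detects the conv prune dimension via `set_conv_prune_dim`, and propagates channel masks through the graph) rebuild the network with genuinely smaller layers. A rough usage sketch against the NNI API of this era; the model class and file names are assumptions:

```python
import torch
from nni.compression.torch import L1FilterPruner, ModelSpeedup

model = MyModel()                        # hypothetical network definition
dummy_input = torch.randn(1, 3, 32, 32)

# mask 50% of the output channels of every Conv2d
pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
pruner.compress()
pruner.export_model(model_path='pruned.pth', mask_path='mask.pth')

# apply the exported masks to a fresh copy and shrink it in place
model = MyModel()
ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()
```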