From 8bc74a2c9f87674990ace81818ecac2ea9016ac9 Mon Sep 17 00:00:00 2001
From: chicm-ms <38930155+chicm-ms@users.noreply.github.com>
Date: Fri, 9 Oct 2020 16:42:35 +0800
Subject: [PATCH] Speedup supports channel pruning (#2906)

---
 src/sdk/pynni/nni/_graph_utils.py             |  32 ++
 .../nni/compression/torch/pruning/one_shot.py |  12 +-
 .../torch/speedup/compress_modules.py         |  19 +-
 .../compression/torch/speedup/compressor.py   |  46 +--
 .../compression/torch/speedup/infer_shape.py  | 335 +++++++++++++++---
 .../compression/torch/utils/mask_conflict.py  | 189 +++++++---
 .../nni/compression/torch/utils/utils.py      |  30 ++
 src/sdk/pynni/tests/test_compression_utils.py |   2 +-
 src/sdk/pynni/tests/test_model_speedup.py     | 118 +++++-
 9 files changed, 639 insertions(+), 144 deletions(-)
 create mode 100644 src/sdk/pynni/nni/compression/torch/utils/utils.py

diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py
index 3fa6cd0eab..0f18a91c8a 100644
--- a/src/sdk/pynni/nni/_graph_utils.py
+++ b/src/sdk/pynni/nni/_graph_utils.py
@@ -426,6 +426,36 @@ def _extract_cat_info(self, node_group, cpp_node):
         cat_info['in_shape'] = input_shapes
         return cat_info
 
+    def _extract_linear_shape_info(self, node_group):
+        """
+        Extract the input/output tensor shape info of a linear module from its aten::addmm op.
+
+        Parameters
+        ----------
+        node_group : NodePyGroup
+            NodePyGroup object associated with the linear module.
+
+        Returns
+        -------
+        dict
+            A dict containing the shapes of the input and output tensors
+        """
+        for cpp_node in node_group.node_cpps:
+            if cpp_node.kind() == 'aten::addmm':
+                # https://github.com/pytorch/pytorch/blob/1.6/torch/nn/functional.py#L1682
+                # inputs of aten::addmm:
+                # inputs[0] is bias
+                # inputs[1] is input data
+                # inputs[2] is weight
+                t_input = list(cpp_node.inputs())[1]
+                t_output = cpp_node.output()
+                assert isinstance(t_input.type(), torch._C.TensorType)
+                assert isinstance(t_output.type(), torch._C.TensorType)
+                in_shape = t_input.type().sizes()
+                out_shape = t_output.type().sizes()
+                return {'in_shape': in_shape, 'out_shape': out_shape}
+        return None
+
     def _extract_shape_info(self, node):
         """
         Extract the shape information of ```aten::view``` node
@@ -701,6 +731,8 @@ def _extract_auxiliary_info(self):
                 cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
                                        node_group.node_cpps))[0]
                 node_group.auxiliary = self._extract_shape_info(cpp_node)
+            elif node_group.op_type == 'Linear':
+                node_group.auxiliary = self._extract_linear_shape_info(node_group)
             elif node_group.op_type == CAT_KIND:
                 # get the detail information for cat func
                 cpp_node = list(filter(lambda x: x.kind() == node_group.op_type,
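Note: for readers unfamiliar with the aten::addmm convention relied on above: in the PyTorch 1.x releases this patch targets, a Linear layer applied to a 2D input is traced as a single aten::addmm node whose first three inputs are (bias, input, weight); newer PyTorch versions may emit aten::linear instead. A minimal standalone sketch of locating that node in a traced graph (not part of the patch; shape recording via sizes() varies across versions):

    import torch

    model = torch.nn.Linear(8, 4)
    traced = torch.jit.trace(model, torch.randn(2, 8))
    for node in traced.graph.nodes():
        if node.kind() == 'aten::addmm':
            # the patch reads shapes via input/output .type().sizes()
            bias_v, input_v, weight_v = list(node.inputs())[:3]
            print(input_v.type(), '->', node.output().type())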
diff --git a/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py b/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
index 3a096176d4..1958af91e2 100644
--- a/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
+++ b/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 
 import logging
-from schema import And, Optional
+from schema import And, Optional, SchemaError
 from nni._graph_utils import TorchModuleGraph
 from nni.compression.torch.utils.shape_dependency import ChannelDependency, GroupDependency
 from .constants import MASKER_DICT
@@ -186,12 +186,16 @@ def update_mask(self):
 
     def validate_config(self, model, config_list):
         schema = CompressorSchema([{
-            'sparsity': And(float, lambda n: 0 < n < 1),
-            'op_types': ['Conv2d'],
-            Optional('op_names'): [str]
+            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
+            Optional('op_types'): ['Conv2d'],
+            Optional('op_names'): [str],
+            Optional('exclude'): bool
        }], model, logger)
 
         schema.validate(config_list)
+        for config in config_list:
+            if 'exclude' not in config and 'sparsity' not in config:
+                raise SchemaError('Either sparsity or exclude must be specified!')
 
     def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None):
         """
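Note: under the relaxed schema, every entry must carry either 'sparsity' or 'exclude'. A config_list that prunes all Conv2d and Linear layers except one named layer might look like this (the layer name 'conv1' is illustrative):

    config_list = [{
        'sparsity': 0.5,
        'op_types': ['Conv2d', 'Linear']
    }, {
        'op_names': ['conv1'],
        'exclude': True    # this entry may omit 'sparsity' because it excludes layers
    }]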
diff --git a/src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py b/src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
index 37d6a8e1e1..413f32c6df 100644
--- a/src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
+++ b/src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py
@@ -116,15 +116,19 @@ def replace_conv2d(conv, mask):
     else:
         out_channels_index = mask.output_mask.mask_index[1]
         out_channels = out_channels_index.size()[0]
-
-    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
+    groups = conv.groups
+    if conv.in_channels == conv.out_channels == conv.groups:
+        # remove groups for depthwise layers
+        assert in_channels == out_channels
+        groups = in_channels
+    _logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
                                stride=conv.stride,
                                padding=conv.padding,
                                dilation=conv.dilation,
-                               groups=conv.groups,
+                               groups=groups,
                                bias=conv.bias is not None,
                                padding_mode=conv.padding_mode)
 
@@ -142,13 +146,16 @@ def replace_conv2d(conv, mask):
     # channel is also divided into several groups and each group
     # filter may have different input channel indexes.
     input_step = int(conv.in_channels / conv.groups)
-    in_channels_group = int(in_channels / conv.groups)
-    filter_step = int(out_channels / conv.groups)
-    if mask.input_mask is not None:
+    in_channels_group = int(in_channels / groups)
+    filter_step = int(out_channels / groups)
+    if mask.input_mask is not None and not (in_channels == out_channels == groups):
         for groupid in range(conv.groups):
             start = groupid * input_step
             end = (groupid + 1) * input_step
             current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
+            if not current_input_index:
+                # there are no kept channels in the current group
+                continue
             # shift the global index into the group index
             current_input_index = [x-start for x in current_input_index]
             # if the groups is larger than 1, the input channels of each
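Note: outside the grouped case handled above, the essence of rebuilding a pruned conv layer is index-selecting the kept channels into a smaller module. A simplified sketch of that idea (ignoring groups and dilation; a hypothetical helper, not the actual replace_conv2d):

    import torch

    def rebuild_conv(conv, in_idx, out_idx):
        # build a smaller conv and copy the kept weights into it
        new_conv = torch.nn.Conv2d(len(in_idx), len(out_idx),
                                   kernel_size=conv.kernel_size,
                                   stride=conv.stride,
                                   padding=conv.padding,
                                   bias=conv.bias is not None)
        w = conv.weight.data.index_select(0, out_idx)   # keep selected filters
        w = w.index_select(1, in_idx)                   # keep selected input channels
        new_conv.weight.data.copy_(w)
        if conv.bias is not None:
            new_conv.bias.data.copy_(conv.bias.data.index_select(0, out_idx))
        return new_conv

    conv = torch.nn.Conv2d(4, 6, 3)
    smaller = rebuild_conv(conv, torch.tensor([0, 2]), torch.tensor([1, 3, 5]))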
diff --git a/src/sdk/pynni/nni/compression/torch/speedup/compressor.py b/src/sdk/pynni/nni/compression/torch/speedup/compressor.py
index 41753e1c9f..a9ae7aef54 100644
--- a/src/sdk/pynni/nni/compression/torch/speedup/compressor.py
+++ b/src/sdk/pynni/nni/compression/torch/speedup/compressor.py
@@ -4,34 +4,13 @@
 import logging
 import torch
 from nni.compression.torch.utils.mask_conflict import fix_mask_conflict
+from nni.compression.torch.utils.utils import get_module_by_name
 from .compress_modules import replace_module
-from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
+from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape, set_conv_prune_dim
 
 _logger = logging.getLogger(__name__)
 
 
-def get_module_by_name(model, module_name):
-    """
-    Get a module specified by its module name
-
-    Parameters
-    ----------
-    model : pytorch model
-        the pytorch model from which to get its module
-    module_name : str
-        the name of the required module
-
-    Returns
-    -------
-    module, module
-        the parent module of the required module, the required module
-    """
-    name_list = module_name.split(".")
-    for name in name_list[:-1]:
-        model = getattr(model, name)
-    leaf_module = getattr(model, name_list[-1])
-    return model, leaf_module
-
 class ModelSpeedup:
     """
     This class is to speedup the model with provided weight mask
@@ -87,7 +66,8 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
         if module_name in self.inferred_masks:
             module_masks = self.inferred_masks[module_name]
         else:
-            module_masks = ModuleMasks(module_name)
+            _, m = get_module_by_name(self.bound_model, module_name)
+            module_masks = ModuleMasks(module_name, m)
             self.inferred_masks[module_name] = module_masks
 
         m_type = self.torch_graph.name_to_node[module_name].op_type
@@ -98,7 +78,12 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
                 raise RuntimeError(
                     "Inferring input/output shape from mask is not supported for module/function: `{}`, {}"
                     .format(m_type, module_name))
-            input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
+            if m_type in ['Linear']:
+                input_cmask, output_cmask = infer_from_mask[m_type](
+                    module_masks, mask, self.torch_graph.name_to_node[module_name].auxiliary
+                )
+            else:
+                input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
         if in_shape is not None:
             _logger.debug("in_shape is not None")
             if not m_type in infer_from_inshape:
@@ -124,7 +109,10 @@ def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None,
                 raise RuntimeError(
                     "Inferring input shape from output shape is not supported for module/function: `{}`, {}"
                     .format(m_type, module_name))
-            input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
+            if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
+                input_cmask = infer_from_outshape[m_type](module_masks, out_shape, self.torch_graph.name_to_node[module_name].auxiliary)
+            else:
+                input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
 
         if input_cmask:
             predecessors = self.torch_graph.find_predecessors(module_name)
@@ -178,7 +166,6 @@ def replace_compressed_modules(self):
             else:
                 raise RuntimeError("Unsupported node type: {}".format(g_node.type))
 
-
     def speedup_model(self):
         """
         There are basically two steps:
@@ -187,8 +174,11 @@ def speedup_model(self):
         """
         training = self.bound_model.training
         _logger.info("start to speed up the model")
+
         _logger.info("fix the mask conflict of the interdependent layers")
-        fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
+        _, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
+        set_conv_prune_dim(conv_prune_dim)
+
         _logger.info("infer module masks...")
         self.infer_modules_masks()
         _logger.info("replace compressed modules...")
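Note: from a user's perspective the flow is unchanged; ModelSpeedup now calls fix_mask_conflict and set_conv_prune_dim internally before mask inference. A typical invocation (MyModel and the mask path are placeholders):

    import torch
    from nni.compression.torch import ModelSpeedup

    model = MyModel()  # hypothetical model, already pruned with a mask file exported
    dummy_input = torch.randn(1, 3, 224, 224)
    ms = ModelSpeedup(model, dummy_input, 'mask.pth')
    ms.speedup_model()  # fixes mask conflicts, infers masks, replaces modules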
diff --git a/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py b/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
index 368518f3cf..5d636c8784 100644
--- a/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
+++ b/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
@@ -6,8 +6,22 @@
 The other is given input shape, infer its output shape and initialization parameters (e.g., weight's shape)
 """
 
+import logging
 import torch
 
+_logger = logging.getLogger(__name__)
+
+conv_prune_dim = -1
+
+def set_conv_prune_dim(dim):
+    """
+    Set the dimension along which conv layers are pruned.
+
+    dim = 0: filter pruning
+    dim = 1: channel pruning
+    """
+    global conv_prune_dim
+    conv_prune_dim = dim
 
 class CoarseMask:
     """
@@ -160,7 +174,7 @@ class ModuleMasks:
     The masks of a module, including the masks for weights, inputs, output
     """
 
-    def __init__(self, module_name):
+    def __init__(self, module_name, module=None):
         """
         Parameters
         ----------
@@ -168,6 +182,7 @@ def __init__(self, module_name):
             The name of the module or function
         """
         self.module_name = module_name
+        self.module = module
         self.param_masks = dict()
         self.input_mask = None
         self.output_mask = None
@@ -202,8 +217,8 @@ def set_output_mask(self, mask):
         self.output_mask = mask
 
     def __repr__(self):
-        return 'input_mask: {}, output_mask: {}, param_masks: {}'.format(
-            self.input_mask, self.output_mask, self.param_masks
+        return 'module_name: {}, input_mask: {}, output_mask: {}, param_masks: {}'.format(
+            self.module_name, self.input_mask, self.output_mask, self.param_masks
         )
 
 
@@ -212,7 +227,8 @@ def __repr__(self):
 """
 infer_from_mask = {
     'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
-    'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask)
+    'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
+    'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
 }
 
 """
@@ -260,7 +276,34 @@ def __repr__(self):
 Infer input and weight shape of a module/function from its output shape
 """
 infer_from_outshape = {
-    'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask)
+    'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
+    'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
+
+    'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'aten::max_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'aten::adaptive_avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'AvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+    'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
+
+    'ReLU': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'ReLU6': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::relu': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::tanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::tanh_': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::hardtanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::hardtanh_': lambda module_masks, mask: relu_outshape(module_masks, mask),
+    'aten::relu_': lambda module_masks, mask: relu_outshape(module_masks, mask),
+
+    'aten::add_': lambda module_masks, mask: add_outshape(module_masks, mask),
+    'aten::add': lambda module_mask, mask: add_outshape(module_mask, mask),
+    'aten::flatten': lambda module_mask, mask, shape: view_outshape(module_mask, mask, shape),
+    'aten::view': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape),
+    'aten::reshape': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape),
+    'aten::mean': lambda module_masks, mask, shape: mean_outshape(module_masks, mask, shape),
+    'Dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
+    'Dropout2d': lambda module_masks, mask: dropout_outshape(module_masks, mask),
+    'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
 }
 
@@ -282,7 +325,15 @@ def dropout_inshape(module_masks, mask):
         module_masks.set_output_mask(mask)
         return module_masks.output_mask
 
+def dropout_outshape(module_masks, mask):
+    if module_masks.output_mask is None:
+        module_masks.set_output_mask(mask)
+        module_masks.set_input_mask(mask)
+        return module_masks.input_mask
+    # if already visited
+    assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+    return module_masks.output_mask
 
 def cat_inshape(module_masks, mask, cat_info, last_visited):
     """
@@ -382,6 +433,20 @@ def add_inshape(module_masks, mask):
         raise Exception('Mask conflict happens!')
     return None
 
+def add_outshape(module_masks, mask):
+    """
+    Infer the input mask of the add operation from the
+    output mask.
+    """
+    assert isinstance(mask, CoarseMask)
+
+    if module_masks.output_mask is None:
+        module_masks.set_output_mask(mask)
+        module_masks.set_input_mask(mask)
+        return mask
+    else:
+        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+        return mask
 
 def batchnorm2d_inshape(module_masks, mask):
     """
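Note: all of the *_outshape helpers follow the same contract: the incoming CoarseMask describes the output tensor and, for shape-preserving ops, is propagated backwards unchanged, with an assert that any previously recorded mask agrees. A toy illustration of that channel-index comparison (plain tensors mimicking mask_index[1], not the real CoarseMask class):

    import torch

    kept_a = torch.tensor([0, 2, 3])   # kept channel indexes seen on one path
    kept_b = torch.tensor([0, 2, 3])   # kept channel indexes seen on another path
    # the asserts above reduce to this element-wise index comparison
    assert all(kept_a == kept_b)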
+ """ + assert isinstance(mask, CoarseMask) + + if module_masks.output_mask is None: + module_masks.set_output_mask(mask) + module_masks.set_input_mask(mask) + return mask + else: + assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1]) + return mask def batchnorm2d_inshape(module_masks, mask): """ @@ -412,6 +477,34 @@ def batchnorm2d_inshape(module_masks, mask): module_masks.set_param_masks('bias', weight_cmask) return mask +def batchnorm2d_outshape(module_masks, mask): + """ + We assume only the second dimension has coarse grained mask + + Parameters + ---------- + module_masks : ModuleMasks + The ModuleMasks instance of the batchnorm2d + mask : CoarseMask + The mask of its input tensor + + Returns + ------- + CoarseMask + The mask of its output tensor + """ + assert isinstance(mask, CoarseMask) + assert len(mask.mask_index) in [2, 4] + assert mask.mask_index[1] is not None + assert mask.mask_index[0] is None + module_masks.set_input_mask(mask) + module_masks.set_output_mask(mask) + weight_cmask = CoarseMask(num_dim=1) + weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1]) + module_masks.set_param_masks('weight', weight_cmask) + module_masks.set_param_masks('bias', weight_cmask) + return mask + def linear_inshape(module_masks, mask): """ @@ -484,6 +577,42 @@ def view_inshape(module_masks, mask, shape): module_masks.set_output_mask(output_cmask) return output_cmask +def view_outshape(module_masks, mask, shape): + """ + Parameters + ---------- + module_masks : ModuleMasks + The ModuleMasks instance of the ```flatten``` op + mask : CoarseMask + The mask of its input tensor + shape : dict + Original shape of its input and output tensors + Returns + ------- + CoarseMask + The mask of its output tensor + """ + # NOTE: the case constrained by the following four asserts + assert shape['in_shape'][0] == shape['out_shape'][0] + assert len(shape['in_shape']) == 4 + assert len(shape['out_shape']) == 2 + assert shape['out_shape'][1] == shape['in_shape'][1] * \ + shape['in_shape'][2]*shape['in_shape'][3] + + assert isinstance(mask, CoarseMask) + assert mask.mask_index[1] is not None + assert mask.mask_index[0] is None + + module_masks.set_output_mask(mask) + input_cmask = CoarseMask(num_dim=4) + index = [] + step_size = shape['in_shape'][2] * shape['in_shape'][3] + for loc in mask.mask_index[1]: + index.extend([loc * step_size + i for i in range(step_size)]) + input_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable + module_masks.set_input_mask(input_cmask) + + return input_cmask def size_inshape(module_masks, mask): """ @@ -513,6 +642,26 @@ def mean_inshape(module_masks, mask, shape): module_masks.set_output_mask(output_cmask) return output_cmask +def mean_outshape(module_masks, mask, shape): + """ + Similar to view operation, currently mask inference only supports + the mean operation on the 3rd and 4th dimensions. 
+ """ + assert shape['in_shape'][0] == shape['out_shape'][0] + assert shape['out_shape'][1] == shape['in_shape'][1] + assert len(shape['in_shape']) == 4 + assert len(shape['out_shape']) == 2 + + assert isinstance(mask, CoarseMask) + assert mask.mask_index[1] is not None + assert mask.mask_index[0] is None + module_masks.set_output_mask(mask) + + input_cmask = CoarseMask(num_dim=4) + input_cmask.add_index_mask(dim=1, index=mask.mask_index[1]) + module_masks.set_input_mask(input_cmask) + return input_cmask + def maxpool2d_inshape(module_masks, mask): """ Assume only the second dimension is masked @@ -541,6 +690,29 @@ def maxpool2d_inshape(module_masks, mask): module_masks.set_output_mask(mask) return mask +def maxpool2d_outshape(module_masks, mask): + """ + Assume only the second dimension is masked + + Parameters + ---------- + module_masks : ModuleMasks + The ModuleMasks instance of the maxpool2d + mask : CoarseMask + The mask of its input tensor + + Returns + ------- + CoarseMask + The mask of its output tensor + """ + assert isinstance(mask, CoarseMask) + assert mask.mask_index[1] is not None + assert mask.mask_index[0] is None + + module_masks.set_input_mask(mask) + module_masks.set_output_mask(mask) + return mask def relu_inshape(module_masks, mask): """ @@ -558,25 +730,44 @@ def relu_inshape(module_masks, mask): """ assert isinstance(mask, CoarseMask) if module_masks.input_mask is not None: - # check if has a mask conflict + # mask conflict should be solved before speedup assert module_masks.input_mask <= mask # assert module_masks.input_mask is None, "A relu op can only be processed once" module_masks.set_input_mask(mask) module_masks.set_output_mask(mask) return mask +def relu_outshape(module_masks, mask): + """ + Parameters + ---------- + module_masks : ModuleMasks + The ModuleMasks instance of the relu + mask : CoarseMask + The mask of its input tensor + + Returns + ------- + CoarseMask + The mask of its output tensor + """ + assert isinstance(mask, CoarseMask) + if module_masks.output_mask is not None: + # mask conflict should be solved before speedup + assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1]) + module_masks.set_input_mask(mask) + module_masks.set_output_mask(mask) + return mask def batchnorm2d_mask(module_masks, mask): """ Infer input and output shape from weight mask - Parameters ---------- module_masks : ModuleMasks The ModuleMasks instance of the batchnorm2d mask : dict The mask of its weights, from the user provided mask file - Returns ------- CoarseMask, CoarseMask @@ -601,6 +792,38 @@ def batchnorm2d_mask(module_masks, mask): module_masks.set_output_mask(output_cmask) return input_cmask, output_cmask +def linear_mask(module_masks, mask, shape): + """ + Infer input and output shape from weight mask with limitations: + Only support infer input mask + + Parameters + ---------- + module_masks : ModuleMasks + The ModuleMasks instance of the Linear + mask : dict + The mask of its weights, from the user provided mask file + shape: dict + Shape of its input and output tensors + Returns + ------- + CoarseMask, CoarseMask + The mask of its input tensor, the mask of its output tensor + """ + + assert 'weight' in mask + num_input_dim = len(shape['in_shape']) + + # Input data of Linear module can have multiple dimensions. 
@@ -618,12 +841,15 @@ def conv2d_mask(module_masks, mask):
     CoarseMask, CoarseMask
         The mask of its input tensor, the mask of its output tensor
     """
-    def convert_to_coarse_mask(mask):
+    def convert_to_coarse_mask(mask, dim=0):
         """
         Parameters
         ----------
         mask : dict
             Weight mask from user provided mask file
+        dim : int
+            0: filter pruning
+            1: channel pruning
 
         Returns
         -------
@@ -632,64 +858,69 @@ def convert_to_coarse_mask(mask):
         """
         assert 'weight' in mask
         assert isinstance(mask['weight'], torch.Tensor)
+        assert dim in [0, 1]
+
         weight_mask = mask['weight']
-        shape = weight_mask.size()
-        ones = torch.ones(shape[1:]).to(weight_mask.device)
-        zeros = torch.zeros(shape[1:]).to(weight_mask.device)
-        index = []
-        for i in range(shape[0]):
-            if torch.all(torch.eq(weight_mask[i], ones)):
-                index.append(i)
-            elif torch.all(torch.eq(weight_mask[i], zeros)):
-                continue
-            else:
-                index = None
-                break
+
+        sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
+        index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
+        if len(index) == weight_mask.shape[dim]:  # full mask
+            index = None
+
         if index is None:
             return None, None, None
         else:
             index = torch.LongTensor(index).to(weight_mask.device)
             weight_cmask = CoarseMask(num_dim=4)
-            weight_cmask.add_index_mask(dim=0, index=index)
+            weight_cmask.add_index_mask(dim=dim, index=index)
             bias_cmask = None
-            if 'bias' in mask and mask['bias'] is not None:
+            if dim == 0 and 'bias' in mask and mask['bias'] is not None:
                 bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
                 assert torch.all(torch.eq(index, bias_index)), \
                     "bias mask should be consistent with weight mask"
                 bias_cmask = CoarseMask(num_dim=1)
                 bias_cmask.add_index_mask(dim=0, index=bias_index)
             return index, weight_cmask, bias_cmask
-    index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask)
+
+    index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask, dim=conv_prune_dim)
+
     if index is None:
         # TODO: fine grained mask speedup
         return None, None
     # deal with coarse grain mask
+    # mask conflict should be solved by fix_mask_conflict before speedup
     if 'weight' in module_masks.param_masks:
-        module_masks.param_masks['weight'].merge(weight_cmask)
-        module_masks.param_masks['bias'].merge(bias_cmask)
+        assert module_masks.param_masks['weight'] == weight_cmask
     else:
         module_masks.set_param_masks('weight', weight_cmask)
-        module_masks.set_param_masks('bias', bias_cmask)
-    output_cmask = CoarseMask(num_dim=4)
-    output_cmask.add_index_mask(dim=1, index=index)
-    if module_masks.output_mask is None:
-        module_masks.set_output_mask(output_cmask)
-    else:
-        module_masks.output_mask.merge(output_cmask)
-    return None, module_masks.output_mask
+        if conv_prune_dim == 0:
+            module_masks.set_param_masks('bias', bias_cmask)
+
+    io_cmask = CoarseMask(num_dim=4)
+    io_cmask.add_index_mask(dim=1, index=index)
+
+    if conv_prune_dim == 0:
+        if module_masks.output_mask is None:
+            module_masks.set_output_mask(io_cmask)
+        else:
+            assert module_masks.output_mask == io_cmask
+        return None, module_masks.output_mask
+    else:
+        if module_masks.input_mask is None:
+            module_masks.set_input_mask(io_cmask)
+        else:
+            assert module_masks.input_mask == io_cmask
+        return module_masks.input_mask, None
 
 def conv2d_inshape(module_masks, mask):
     """
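Note: the sum_idx trick above generalizes the old per-filter loop: reducing a 4D weight mask over every dim except the pruned one gives per-channel survival. A toy check of both reductions (combining both prune dims in one mask only for arithmetic illustration; real masks prune a single dim):

    import torch

    w_mask = torch.ones(4, 3, 2, 2)      # out_channels, in_channels, kH, kW
    w_mask[1] = 0                        # prune filter 1 (dim 0)
    w_mask[:, 2] = 0                     # prune input channel 2 (dim 1)
    kept_filters = torch.nonzero(w_mask.abs().sum((1, 2, 3)) != 0, as_tuple=True)[0]
    kept_channels = torch.nonzero(w_mask.abs().sum((0, 2, 3)) != 0, as_tuple=True)[0]
    print(kept_filters)   # tensor([0, 2, 3])
    print(kept_channels)  # tensor([0, 1])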
     Shape change of input tensor does not affect the shape of its output tensor
-
     Parameters
     ----------
     module_masks : ModuleMasks
         The ModuleMasks instance of the conv2d
     mask : CoarseMask
         The mask of its input tensor
-
     Returns
     -------
     CoarseMask
@@ -701,8 +932,15 @@ def conv2d_inshape(module_masks, mask):
     else:
         # the same conv layer may be accessed more
         # than once, such as a concat operation.
-        assert module_masks.input_mask <= mask
-        module_masks.input_mask.merge(mask)
+        # mask conflict should be solved by fix_mask_conflict before speedup
+        assert module_masks.input_mask == mask
+
+    # shape changes pass through depthwise conv layers
+    m = module_masks.module
+    if m.in_channels == m.out_channels == m.groups:
+        module_masks.output_mask = mask
+        module_masks.input_mask = mask
+        return mask
     return None
 
 
@@ -728,18 +966,25 @@ def conv2d_outshape(module_masks, mask):
     assert mask.mask_index[2] is None
     assert mask.mask_index[3] is None
 
-    if module_masks.output_mask is not None:
-        assert isinstance(module_masks.output_mask, CoarseMask)
-        # set shape of output
-        mask = module_masks.output_mask.merge(mask)
-    else:
+    if module_masks.output_mask is None:
         module_masks.output_mask = mask
-    # infer shape of parameters
+    else:
+        # mask conflict should be solved by fix_mask_conflict before speedup
+        # mask and module_masks.output_mask may have different numbers of dimensions
+        # since they could be passed by linear or conv2d
+        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+
     weight_cmask = CoarseMask(num_dim=4)
     weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
     bias_cmask = CoarseMask(num_dim=1)
     bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
     module_masks.set_param_masks('weight', weight_cmask)
     module_masks.set_param_masks('bias', bias_cmask)
-    # input shape is not changed
+
+    # shape changes pass through depthwise conv layers
+    m = module_masks.module
+    if m.in_channels == m.out_channels == m.groups:
+        module_masks.output_mask = mask
+        module_masks.input_mask = mask
+        return mask
     return None
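Note: the depthwise special case (in_channels == out_channels == groups) matters because a depthwise conv cannot change its channel count independently; pruning its input implies pruning its output one-to-one. A sketch of rebuilding such a layer under a channel mask (a hypothetical helper, not the actual NNI code):

    import torch

    def rebuild_depthwise(conv, kept):
        assert conv.in_channels == conv.out_channels == conv.groups
        new_conv = torch.nn.Conv2d(len(kept), len(kept),
                                   kernel_size=conv.kernel_size,
                                   stride=conv.stride,
                                   padding=conv.padding,
                                   groups=len(kept),
                                   bias=conv.bias is not None)
        # depthwise weight shape is (C, 1, kH, kW); the same indexes select in and out
        new_conv.weight.data.copy_(conv.weight.data.index_select(0, kept))
        if conv.bias is not None:
            new_conv.bias.data.copy_(conv.bias.data.index_select(0, kept))
        return new_conv

    dw = torch.nn.Conv2d(8, 8, 3, groups=8)
    print(rebuild_depthwise(dw, torch.tensor([0, 3, 5])))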
""" super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced) + self.conv_prune_dim = detect_mask_prune_dim(masks, model) + _logger.info('detected conv prune dim: %s', self.conv_prune_dim) def fix_mask(self): """ Fix the mask conflict before the mask inference for the layers that has shape dependencies. This function should be called before the - mask inference of the 'speedup' module. + mask inference of the 'speedup' module. Only structured pruning masks + are supported. """ - channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced) + if self.conv_prune_dim == 0: + channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced) + else: + channel_depen = InputChannelDependency(self.model, self.dummy_input, self.traced) depen_sets = channel_depen.dependency_sets + sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3) for dset in depen_sets: - if len(dset) == 1: - # This layer has no channel dependency with other layers + if len(dset) <= 1: continue - channel_remain = set() + # channel_masks is a list, each element is None or a vector, for example: + # [[0, 1, 1, 0, 0], [0, 0, 1, 1, 0], None], None means no channel + # is pruned. + channel_masks = [] fine_grained = False - out_channels = None - # A flag that represents if all the layers in - # the dependency set are pruned - all_pruned = True for name in dset: - if name not in self.masks: - # this layer is not pruned - all_pruned = False - continue - w_mask = self.masks[name]['weight'] - if out_channels is None: - out_channels = w_mask.size(0) - shape = w_mask.size() - count = np.prod(shape[1:]) - all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist() - all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist() - if len(all_ones) + len(all_zeros) < w_mask.size(0): - # In fine-grained pruning, there is no need to check - # the shape conflict - _logger.info('Layers %s using fine-grained pruning', ','.join(dset)) - fine_grained = True - break - channel_remain.update(all_ones) - _logger.debug('Layer: %s ', name) - _logger.debug('Original pruned filters: %s', str(all_zeros)) - # Update the masks for the layers in the dependency set - if fine_grained or out_channels is None: - # if use the fine-grained pruner or all the layers in - # this dependency set are not pruned + if name in self.masks: + _, m = get_module_by_name(self.model, name) + assert m is not None + mask = self.masks[name]['weight'] + if type(m).__name__ == 'Conv2d': + channel_mask = (mask.abs().sum(sum_idx) != 0).int() + channel_masks.append(channel_mask) + if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() != (mask > 0).sum().item(): + fine_grained = True + elif type(m).__name__ == 'Linear': + channel_masks.append((mask.abs().sum(0) != 0).int()) + elif type(m).__name__ == 'BatchNorm2d': + channel_masks.append(mask.int()) + else: + raise RuntimeError(f'unsupported module type: {type(m).__name__}') + else: + # no mask means not pruned, equivlent to full masks + channel_masks.append(None) + if fine_grained: + _logger.info('fine-grained mask detected, skip solving conflict for this set: %s', dset) continue - if not all_pruned: - # if some layer are not pruned at all - # then all the layers in this dependency set - # cannot be pruned due to the shape dependency. 
+
+def detect_mask_prune_dim(masks, model):
+    """
+    Detect how the masks of convolutional layers are pruned.
+
+    Parameters
+    ----------
+    masks : dict
+        A dict object that stores the masks.
+    model : nn.Module
+        Model object which the mask can be applied on.
+
+    Returns
+    -------
+    int
+        How the masks of convolutional layers are pruned. This depends on the pruning
+        algorithm: it returns 1 for masks generated by AMCPruner, and 0 for masks
+        generated by the other NNI builtin pruners.
+        0: filter pruning, i.e. pruning filters of the weights, which prunes channels
+        of the output feature maps.
+        1: channel pruning, i.e. pruning the kernels corresponding to each input
+        channel, which prunes channels of the input feature maps.
+    """
+    dim0_preserved, dim1_preserved = 0., 0.
+    dim0_num, dim1_num = 0., 0.
+    for module_name in masks:
+        _, m = get_module_by_name(model, module_name)
+        if m is None or type(m).__name__ != 'Conv2d':
+            continue
+
+        mask = masks[module_name]['weight'].clone()
+        assert (mask >= 0).sum() == mask.numel(), \
+            "mask values should be greater than or equal to 0."
+        mask = (mask > 0).int()
+        mask = mask.view(mask.shape[0], mask.shape[1], -1)
+        dim0_mask = (mask.sum((1, 2)) > 0).int()
+        dim1_mask = (mask.sum((0, 2)) > 0).int()
+        dim0_preserved += dim0_mask.sum().item()
+        dim1_preserved += dim1_mask.sum().item()
+        dim0_num += len(dim0_mask)
+        dim1_num += len(dim1_mask)
+
+    if dim0_num == 0 or dim1_num == 0:
+        _logger.warning('no multi-dimension masks found.')
+        return 0
+
+    dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / dim0_num, 1. - dim1_preserved / dim1_num
+    _logger.info('dim0 sparsity: %f', dim0_sparsity)
+    _logger.info('dim1 sparsity: %f', dim1_sparsity)
+
+    if dim0_sparsity == dim1_sparsity == 0.:
+        _logger.warning('nothing masked.')
+
+    if dim0_sparsity > 0 and dim1_sparsity > 0:
+        _logger.warning('both dim0 and dim1 masks found.')
+
+    return 0 if dim0_sparsity >= dim1_sparsity else 1
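Note: the heuristic can be exercised on a toy mask: a filter-pruned mask is fully dense along dim 1 and sparse along dim 0, so dim0 sparsity dominates and 0 is returned. A self-contained check of the same reductions (reimplementing the sums inline rather than importing NNI):

    import torch

    mask = torch.ones(4, 3, 3, 3)
    mask[[0, 2]] = 0                          # filter pruning on dim 0
    m = (mask > 0).int().view(mask.shape[0], mask.shape[1], -1)
    dim0_sparsity = 1. - (m.sum((1, 2)) > 0).int().sum().item() / m.shape[0]
    dim1_sparsity = 1. - (m.sum((0, 2)) > 0).int().sum().item() / m.shape[1]
    print(dim0_sparsity, dim1_sparsity)       # 0.5 0.0 -> prune dim is 0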
diff --git a/src/sdk/pynni/nni/compression/torch/utils/utils.py b/src/sdk/pynni/nni/compression/torch/utils/utils.py
new file mode 100644
index 0000000000..c687c5e2a6
--- /dev/null
+++ b/src/sdk/pynni/nni/compression/torch/utils/utils.py
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+def get_module_by_name(model, module_name):
+    """
+    Get a module specified by its module name
+
+    Parameters
+    ----------
+    model : pytorch model
+        the pytorch model from which to get its module
+    module_name : str
+        the name of the required module
+
+    Returns
+    -------
+    module, module
+        the parent module of the required module, the required module
+    """
+    name_list = module_name.split(".")
+    for name in name_list[:-1]:
+        if hasattr(model, name):
+            model = getattr(model, name)
+        else:
+            return None, None
+    if hasattr(model, name_list[-1]):
+        leaf_module = getattr(model, name_list[-1])
+        return model, leaf_module
+    else:
+        return None, None
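Note: the relocated helper returns (parent, leaf) so callers can also rebind the attribute on the parent module, and it now returns (None, None) for a bad path instead of raising. A minimal usage check:

    import torch
    from nni.compression.torch.utils.utils import get_module_by_name

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
    parent, leaf = get_module_by_name(model, '0')
    assert leaf is model[0]
    assert get_module_by_name(model, 'no.such.module') == (None, None)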
diff --git a/src/sdk/pynni/tests/test_compression_utils.py b/src/sdk/pynni/tests/test_compression_utils.py
index 90c88db573..1be121bbd1 100644
--- a/src/sdk/pynni/tests/test_compression_utils.py
+++ b/src/sdk/pynni/tests/test_compression_utils.py
@@ -115,7 +115,7 @@ def test_mask_conflict(self):
         pruner.export_model(ck_file, mask_file)
         pruner._unwrap_model()
         # Fix the mask conflict
-        fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)
+        fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input)
 
         # use the channel dependency ground truth to check if
         # the mask conflict is fixed successfully
diff --git a/src/sdk/pynni/tests/test_model_speedup.py b/src/sdk/pynni/tests/test_model_speedup.py
index d181c7a289..e61d2efd5c 100644
--- a/src/sdk/pynni/tests/test_model_speedup.py
+++ b/src/sdk/pynni/tests/test_model_speedup.py
@@ -12,6 +12,8 @@
 from unittest import TestCase, main
 
 from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup
+from nni.compression.torch.pruning.weight_masker import WeightMasker
+from nni.compression.torch.pruning.one_shot import _StructuredFilterPruner
 
 torch.manual_seed(0)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -104,6 +106,74 @@ def zero_bn_bias(model):
             shape = module.running_mean.data.size()
             module.running_mean = torch.zeros(shape).to(device)
 
+class L1ChannelMasker(WeightMasker):
+    def __init__(self, model, pruner):
+        self.model = model
+        self.pruner = pruner
+
+    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
+        msg = 'module type {} is not supported!'.format(wrapper.type)
+        #assert wrapper.type == 'Conv2d', msg
+        weight = wrapper.module.weight.data
+        bias = None
+        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
+            bias = wrapper.module.bias.data
+
+        if wrapper.weight_mask is None:
+            mask_weight = torch.ones(weight.size()).type_as(weight).detach()
+        else:
+            mask_weight = wrapper.weight_mask.clone()
+        if bias is not None:
+            if wrapper.bias_mask is None:
+                mask_bias = torch.ones(bias.size()).type_as(bias).detach()
+            else:
+                mask_bias = wrapper.bias_mask.clone()
+        else:
+            mask_bias = None
+        base_mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
+
+        num_total = weight.size(1)
+        num_prune = int(num_total * sparsity)
+
+        if num_total < 2 or num_prune < 1:
+            return base_mask
+        w_abs = weight.abs()
+        if wrapper.type == 'Conv2d':
+            w_abs_structured = w_abs.sum((0, 2, 3))
+            threshold = torch.topk(w_abs_structured, num_prune, largest=False)[0].max()
+            mask_weight = torch.gt(w_abs_structured, threshold)[None, :, None, None].expand_as(weight).type_as(weight)
+            return {'weight_mask': mask_weight.detach()}
+        else:
+            # Linear
+            assert wrapper.type == 'Linear'
+            w_abs_structured = w_abs.sum(0)
+            threshold = torch.topk(w_abs_structured, num_prune, largest=False)[0].max()
+            mask_weight = torch.gt(w_abs_structured, threshold)[None, :].expand_as(weight).type_as(weight)
+            return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
+
+class L1ChannelPruner(_StructuredFilterPruner):
+    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
+        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
+                         dependency_aware=dependency_aware, dummy_input=dummy_input)
+
+    def validate_config(self, model, config_list):
+        pass
+
+def channel_prune(model):
+    config_list = [{
+        'sparsity': SPARSITY,
+        'op_types': ['Conv2d', 'Linear']
+    }, {
+        'op_names': ['conv1'],
+        'exclude': True
+    }]
+
+    pruner = L1ChannelPruner(model, config_list)
+    masker = L1ChannelMasker(model, pruner)
+    pruner.masker = masker
+    pruner.compress()
+    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
+
 class SpeedupTestCase(TestCase):
     def test_speedup_vgg16(self):
         prune_model_l1(vgg16())
@@ -145,10 +215,20 @@ def test_speedup_bigmodel(self):
         assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
 
     def test_speedup_integration(self):
-        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3']:
+        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3', 'resnet50']:
+            kwargs = {
+                'pretrained': True
+            }
+            if model_name == 'resnet50':
+                # testing multiple groups
+                kwargs = {
+                    'pretrained': False,
+                    'groups': 4
+                }
+
             Model = getattr(models, model_name)
-            net = Model(pretrained=True, progress=False).to(device)
-            speedup_model = Model().to(device)
+            net = Model(**kwargs).to(device)
+            speedup_model = Model(**kwargs).to(device)
             net.eval() # this line is necessary
             speedup_model.eval()
             # random generate the prune config for the pruner
@@ -165,6 +245,9 @@ def test_speedup_integration(self):
             data = torch.ones(BATCH_SIZE, 3, 224, 224).to(device)
             ms = ModelSpeedup(speedup_model, data, MASK_FILE)
             ms.speedup_model()
+
+            speedup_model.eval()
+
             ori_out = net(data)
             speeded_out = speedup_model(data)
             ori_sum = torch.sum(ori_out).item()
@@ -174,6 +257,35 @@ def test_speedup_integration(self):
             assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                 (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
 
+    def test_channel_prune(self):
+        orig_net = resnet18(num_classes=10).to(device)
+        channel_prune(orig_net)
+
+        state_dict = torch.load(MODEL_FILE)
+        orig_net = resnet18(num_classes=10).to(device)
+        orig_net.load_state_dict(state_dict)
+        apply_compression_results(orig_net, MASK_FILE)
+        orig_net.eval()
+
+        net = resnet18(num_classes=10).to(device)
+        net.load_state_dict(state_dict)
+        net.eval()
+
+        data = torch.randn(BATCH_SIZE, 3, 224, 224).to(device)
+        ms = ModelSpeedup(net, data, MASK_FILE)
+        ms.speedup_model()
+        ms.bound_model(data)
+
+        net.eval()
+
+        ori_sum = orig_net(data).abs().sum().item()
+        speeded_sum = net(data).abs().sum().item()
+
+        print(ori_sum, speeded_sum)
+        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
+            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
+
     def tearDown(self):
         os.remove(MODEL_FILE)
         os.remove(MASK_FILE)