From 7eb974c0a5025a20b273c84be8cc7d212f6b2cee Mon Sep 17 00:00:00 2001
From: Ningxin Zheng
Date: Tue, 3 Aug 2021 01:14:10 +0000
Subject: [PATCH 1/2] Support Speedup for Slim Pruner.

The Slim pruner prunes BN layers, which are not supported by the
mask_conflict utils. In this PR, we add support for BN layers in
mask_conflict so that the Slim pruner is supported in speedup.
---
 .../pytorch/utils/mask_conflict.py    | 37 ++++++++++++++++++-
 .../pytorch/utils/shape_dependency.py | 23 +++++++++---
 2 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/nni/compression/pytorch/utils/mask_conflict.py b/nni/compression/pytorch/utils/mask_conflict.py
index b797d61f25..de5c9586a6 100644
--- a/nni/compression/pytorch/utils/mask_conflict.py
+++ b/nni/compression/pytorch/utils/mask_conflict.py
@@ -184,6 +184,7 @@ def __init__(self, masks, model=None, dummy_input=None, traced=None):
         super(ChannelMaskConflict, self).__init__(
             masks, model, dummy_input, traced)
         self.conv_prune_dim = detect_mask_prune_dim(masks, model)
+        self.channel_prune_type = detect_channel_prune_type(masks, model)
         _logger.info('Detected conv prune dim: %d', self.conv_prune_dim)
 
     def fix_mask(self):
@@ -200,7 +201,8 @@ def fix_mask(self):
         """
         if self.conv_prune_dim == 0:
             channel_depen = ChannelDependency(
-                self.model, self.dummy_input, self.traced)
+                self.model, self.dummy_input, self.traced, self.channel_prune_type)
+
         else:
             channel_depen = InputChannelDependency(
                 self.model, self.dummy_input, self.traced)
@@ -307,10 +309,43 @@ def fix_mask(self):
 
         return self.masks
 
+
+def detect_channel_prune_type(masks, model):
+    """
+    A user can prune a channel in two ways: 1) prune
+    the corresponding filter of the conv layer (all the
+    filter-related pruners), 2) prune the BN layer that
+    follows a conv (Slim pruner). This function finds
+    the pruning type of the masks.
+
+    Parameters
+    ----------
+    masks: dict
+        A dict object that stores the masks.
+    model: nn.Module
+        Model object which the mask can be applied on.
+    Returns
+    -------
+    prune_type: str
+        Could be 'Filter' or 'Batchnorm'.
+    """
+    prune_type = 'Filter'
+    all_batch_norm = True
+    for layer_name in masks:
+        _, m = get_module_by_name(model, layer_name)
+        if m is None or (not isinstance(m, torch.nn.BatchNorm2d)):
+            all_batch_norm = False
+            break
+    if all_batch_norm:
+        # if all masks are for batchnorm layers, then the prune_type is 'Batchnorm'.
+        # Note: we currently do not support pruning both Conv and BatchNorm
+        # at the same time.
+        prune_type = 'Batchnorm'
+    return prune_type
 
 def detect_mask_prune_dim(masks, model):
     """
     Detect how the masks of convolutional layers are pruned.
+
     Parameters
     ----------
     masks: dict
diff --git a/nni/compression/pytorch/utils/shape_dependency.py b/nni/compression/pytorch/utils/shape_dependency.py
index 883b731d57..645f3d1a20 100644
--- a/nni/compression/pytorch/utils/shape_dependency.py
+++ b/nni/compression/pytorch/utils/shape_dependency.py
@@ -85,7 +85,7 @@ def reshape_break_channel_dependency(op_node):
 
 
 class ChannelDependency(Dependency):
-    def __init__(self, model=None, dummy_input=None, traced_model=None):
+    def __init__(self, model=None, dummy_input=None, traced_model=None, prune_type='Filter'):
         """
         This class analyzes the channel dependencies between the conv
         layers in a model.
@@ -98,7 +98,18 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
         traced_model : torch._C.Graph
            if we already have the traced graph of the target model, we do not
            need to trace the model again.
-        """
+        prune_type: str
+            This parameter indicates the channel pruning type: 1) `Filter`:
+            prune the filter of the convolution layer to prune the corresponding
+            channels 2) prune the channel in the batchnorm layer
+        """
+        self.prune_type = prune_type
+        self.target_types = []
+        if self.prune_type == 'Filter':
+            self.target_types.extend(['Conv2d', 'Linear', 'ConvTranspose2d'])
+        elif self.prune_type == 'Batchnorm':
+            self.target_types.append('BatchNorm2d')
+
         super(ChannelDependency, self).__init__(
             model, dummy_input, traced_model)
@@ -114,12 +125,13 @@ def _get_parent_layers(self, node):
         parent_layers: list
             nearest parent conv/linear layers for the target node.
         """
+
         parent_layers = []
         queue = []
         queue.append(node)
         while queue:
             curnode = queue.pop(0)
-            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
+            if curnode.op_type in self.target_types:
                 # find the first met conv
                 parent_layers.append(curnode.name)
                 continue
@@ -130,6 +142,7 @@ def _get_parent_layers(self, node):
             parents = [self.graph.name_to_node[name] for name in parents]
             for parent in parents:
                 queue.append(parent)
+
         return parent_layers
 
     def build_dependency(self):
@@ -193,7 +206,7 @@ def export(self, filepath):
             csv_w = csv.writer(csvf, delimiter=',')
             csv_w.writerow(header)
             for node in self.graph.nodes_py.nodes_op:
-                if node.op_type != 'Conv2d' or node in visited:
+                if node.op_type not in self.target_types or node in visited:
                     continue
                 setid += 1
                 row = ['Set %d' % setid]
@@ -220,7 +233,7 @@ def dependency_sets(self):
         d_sets = []
         visited = set()
         for node in self.graph.nodes_py.nodes_op:
-            if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited:
+            if node.op_type not in self.target_types or node in visited:
                 continue
             tmp_set = set()
             if node.name not in self.dependency:

From 7de7d11b86aedcc3b4ae26a481a6c3a91be76c56 Mon Sep 17 00:00:00 2001
From: Ningxin Zheng
Date: Tue, 3 Aug 2021 11:12:36 +0000
Subject: [PATCH 2/2] resolve comments

---
 nni/compression/pytorch/utils/mask_conflict.py    | 6 +-----
 nni/compression/pytorch/utils/shape_dependency.py | 2 +-
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/nni/compression/pytorch/utils/mask_conflict.py b/nni/compression/pytorch/utils/mask_conflict.py
index de5c9586a6..240561fafa 100644
--- a/nni/compression/pytorch/utils/mask_conflict.py
+++ b/nni/compression/pytorch/utils/mask_conflict.py
@@ -188,11 +188,6 @@ def __init__(self, masks, model=None, dummy_input=None, traced=None):
         _logger.info('Detected conv prune dim: %d', self.conv_prune_dim)
 
     def fix_mask(self):
-        """
-        Fix the mask conflict before the mask inference for the layers that
-        have shape dependencies. This function should be called before the
-        mask inference of the 'speedup' module.
-        """
         """
         Fix the mask conflict before the mask inference for the layers that
         have shape dependencies. This function should be called before the
@@ -323,6 +318,7 @@ def detect_channel_prune_type(masks, model):
         A dict object that stores the masks.
     model: nn.Module
        Model object which the mask can be applied on.
+
     Returns
     -------
     prune_type: str

diff --git a/nni/compression/pytorch/utils/shape_dependency.py b/nni/compression/pytorch/utils/shape_dependency.py
index 645f3d1a20..1b92e5553d 100644
--- a/nni/compression/pytorch/utils/shape_dependency.py
+++ b/nni/compression/pytorch/utils/shape_dependency.py
@@ -101,7 +101,7 @@ def __init__(self, model=None, dummy_input=None, traced_model=None, prune_type='
         prune_type: str
             This parameter indicates the channel pruning type: 1) `Filter`:
             prune the filter of the convolution layer to prune the corresponding
-            channels 2) prune the channel in the batchnorm layer
+            channels 2) `Batchnorm`: prune the channel in the batchnorm layer
         """
         self.prune_type = prune_type
         self.target_types = []
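
Note (not part of the patch): a minimal usage sketch of the new detection path. It assumes torchvision's resnet18 and hand-built Slim-style masks; the layer names and mask values below are hypothetical stand-ins for what a pruner such as the Slim pruner would export.

    import torch
    import torchvision.models as models

    from nni.compression.pytorch.utils.mask_conflict import detect_channel_prune_type
    from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

    model = models.resnet18()
    dummy_input = torch.rand(1, 3, 224, 224)

    # Slim-style masks: every masked module is a BatchNorm2d, and zeroed
    # positions in 'weight' mark pruned channels (values made up here).
    masks = {
        'bn1': {'weight': torch.ones(64), 'bias': torch.ones(64)},
        'layer1.0.bn1': {'weight': torch.ones(64), 'bias': torch.ones(64)},
    }
    masks['bn1']['weight'][:8] = 0

    # All masked modules are BatchNorm2d, so the detected type is 'Batchnorm';
    # any non-BN entry would make the detection fall back to 'Filter'.
    assert detect_channel_prune_type(masks, model) == 'Batchnorm'

    # With prune_type='Batchnorm', the BFS in _get_parent_layers looks for
    # BatchNorm2d parents instead of Conv2d/Linear/ConvTranspose2d, so BN
    # layers whose channels must be pruned together land in one set.
    dependency = ChannelDependency(model, dummy_input, prune_type='Batchnorm')
    dependency.export('channel_dependency.csv')  # one row per dependency set

ChannelMaskConflict computes self.channel_prune_type in its __init__ and hands it to ChannelDependency inside fix_mask, so callers only need to supply the masks, the model, and a dummy input as before.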