This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Support Speedup for Slim Pruner. #4008

Merged
merged 2 commits on Aug 3, 2021
43 changes: 37 additions & 6 deletions nni/compression/pytorch/utils/mask_conflict.py
@@ -184,14 +184,10 @@ def __init__(self, masks, model=None, dummy_input=None, traced=None):
super(ChannelMaskConflict, self).__init__(
masks, model, dummy_input, traced)
self.conv_prune_dim = detect_mask_prune_dim(masks, model)
self.channel_prune_type = detect_channel_prune_type(masks, model)
_logger.info('Detected conv prune dim: %d', self.conv_prune_dim)

def fix_mask(self):
"""
Fix the mask conflict before the mask inference for the layers that
have shape dependencies. This function should be called before the
mask inference of the 'speedup' module.
"""
"""
Fix the mask conflict before the mask inference for the layers that
has shape dependencies. This function should be called before the
Expand All @@ -200,7 +196,8 @@ def fix_mask(self):
"""
if self.conv_prune_dim == 0:
channel_depen = ChannelDependency(
self.model, self.dummy_input, self.traced)
self.model, self.dummy_input, self.traced, self.channel_prune_type)

else:
channel_depen = InputChannelDependency(
self.model, self.dummy_input, self.traced)
@@ -307,10 +304,44 @@ def fix_mask(self):

return self.masks

def detect_channel_prune_type(masks, model):
"""
A user can prune a channel in two ways: 1) prune
the corresponding filter of the conv layer (all the
filter-related pruners), 2) prune the BN layer that
follows a conv (Slim pruner). This function finds
the pruning type of the masks.

Parameters
----------
masks: dict
A dict object that stores the masks.
model: nn.Module
Model object that the masks can be applied to.

Returns
-------
prune_type: str
Could be Filter or Batchnorm
"""
prune_type = 'Filter'
all_batch_norm = True
for layer_name in masks:
_, m = get_module_by_name(model, layer_name)
if m is None or (not isinstance(m, torch.nn.BatchNorm2d)):
all_batch_norm = False
break
if all_batch_norm:
# If all masks are for BatchNorm layers, then the prune_type is 'Batchnorm'.
# Note: we currently do not support pruning both Conv and BatchNorm
# layers at the same time.
prune_type = 'Batchnorm'
return prune_type

def detect_mask_prune_dim(masks, model):
"""
Detect how the masks of convolutional layers are pruned.

Parameters
----------
masks: dict
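
For readers trying the new helper, a minimal usage sketch follows. Assumptions: `detect_channel_prune_type` is importable from `nni.compression.pytorch.utils.mask_conflict` as added in this diff, and `masks` maps layer names to the mask dicts the pruners export; the detector only inspects the type of the module each key resolves to, not the mask values themselves.

```python
import torch
import torch.nn as nn
from nni.compression.pytorch.utils.mask_conflict import detect_channel_prune_type

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

model = Net()

# Slim-style masks live only on BatchNorm layers, so the detector
# reports 'Batchnorm' when every masked layer is a BatchNorm2d.
bn_masks = {'bn': {'weight': torch.ones(8), 'bias': torch.ones(8)}}
print(detect_channel_prune_type(bn_masks, model))    # -> 'Batchnorm'

# Any non-BatchNorm entry makes the detector fall back to 'Filter'.
conv_masks = {'conv': {'weight': torch.ones(8, 3, 3, 3)}}
print(detect_channel_prune_type(conv_masks, model))  # -> 'Filter'
```
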
23 changes: 18 additions & 5 deletions nni/compression/pytorch/utils/shape_dependency.py
@@ -85,7 +85,7 @@ def reshape_break_channel_dependency(op_node):


class ChannelDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
def __init__(self, model=None, dummy_input=None, traced_model=None, prune_type='Filter'):
"""
This class analyzes the channel dependencies between the conv
layers in a model.
@@ -98,7 +98,18 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
traced_model : torch._C.Graph
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
prune_type: str
This parameter indicates the channel pruning type: 1) `Filter`:
prune the filters of the convolution layer to prune the corresponding
channels; 2) `Batchnorm`: prune the channels in the BatchNorm layer.
"""
self.prune_type = prune_type
self.target_types = []
if self.prune_type == 'Filter':
self.target_types.extend(['Conv2d', 'Linear', 'ConvTranspose2d'])
elif self.prune_type == 'Batchnorm':
self.target_types.append('BatchNorm2d')

super(ChannelDependency, self).__init__(
model, dummy_input, traced_model)

@@ -114,12 +125,13 @@ def _get_parent_layers(self, node):
parent_layers: list
nearest parent conv/linear layers for the target node.
"""

parent_layers = []
queue = []
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
if curnode.op_type in self.target_types:
# stop at the first layer whose type is in target_types
parent_layers.append(curnode.name)
continue
@@ -130,6 +142,7 @@ def _get_parent_layers(self, node):
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
queue.append(parent)

return parent_layers

def build_dependency(self):
@@ -193,7 +206,7 @@ def export(self, filepath):
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for node in self.graph.nodes_py.nodes_op:
if node.op_type != 'Conv2d' or node in visited:
if node.op_type not in self.target_types or node in visited:
continue
setid += 1
row = ['Set %d' % setid]
@@ -220,7 +233,7 @@ def dependency_sets(self):
d_sets = []
visited = set()
for node in self.graph.nodes_py.nodes_op:
if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited:
if node.op_type not in self.target_types or node in visited:
continue
tmp_set = set()
if node.name not in self.dependency:
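
A short sketch of the new `prune_type` argument in use. Assumptions: torchvision's `resnet18` as a stand-in model, and `dependency_sets` exposed as a property as in the surrounding NNI utilities; neither is part of this diff.

```python
import torch
import torchvision.models as models
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

model = models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)

# With prune_type='Batchnorm', _get_parent_layers stops at BatchNorm2d
# nodes instead of Conv2d/Linear/ConvTranspose2d, so the dependency sets
# group the BN layers whose channels must be pruned together, which are
# exactly the layers the Slim pruner masks.
dep = ChannelDependency(model, dummy_input, prune_type='Batchnorm')
for dep_set in dep.dependency_sets:
    print(dep_set)

# The sets can also be dumped to CSV for inspection.
dep.export('channel_dependency.csv')
```
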