This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
compression speedup: small code refactor (#2065)
QuanluZhang authored Feb 17, 2020
1 parent e6cedb8 commit 43de011
Showing 2 changed files with 24 additions and 26 deletions.
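
The refactor is mechanical: ad-hoc print statements are replaced with a module-level logger, so verbosity is controlled by logging configuration rather than by editing code. A minimal runnable sketch of the pattern (the demo function is hypothetical; the logger idiom matches the diff below):

import logging

_logger = logging.getLogger(__name__)  # one logger per module, named after the module

def demo(in_features):
    # before this commit: print('linear: ', in_features)
    # after: lazy %-style formatting, emitted only when DEBUG is enabled
    _logger.debug("replace linear with new in_features: %d", in_features)

if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    demo(64)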
11 changes: 7 additions & 4 deletions src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
@@ -1,9 +1,12 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import logging
 import torch
 from .infer_shape import ModuleMasks
 
+_logger = logging.getLogger(__name__)
+
 replace_module = {
     'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
     'Conv2d': lambda module, mask: replace_conv2d(module, mask),
@@ -16,6 +19,7 @@ def no_replace(module, mask):
"""
No need to replace
"""
_logger.debug("no need to replace")
return module

def replace_linear(linear, mask):
@@ -37,9 +41,8 @@ def replace_linear(linear, mask):
     assert mask.output_mask is None
     assert not mask.param_masks
     index = mask.input_mask.mask_index[-1]
-    print(mask.input_mask.mask_index)
     in_features = index.size()[0]
-    print('linear: ', in_features)
+    _logger.debug("replace linear with new in_features: %d", in_features)
     new_linear = torch.nn.Linear(in_features=in_features,
                                  out_features=linear.out_features,
                                  bias=linear.bias is not None)
@@ -67,7 +70,7 @@ def replace_batchnorm2d(norm, mask):
     assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
     index = mask.param_masks['weight'].mask_index[0]
     num_features = index.size()[0]
-    print("replace batchnorm2d: ", num_features, index)
+    _logger.debug("replace batchnorm2d with num_features: %d", num_features)
     new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                     eps=norm.eps,
                                     momentum=norm.momentum,
@@ -106,6 +109,7 @@ def replace_conv2d(conv, mask):
     else:
         out_channels_index = mask.output_mask.mask_index[1]
         out_channels = out_channels_index.size()[0]
+    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
@@ -128,6 +132,5 @@ def replace_conv2d(conv, mask):
     assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
     new_conv.weight.data.copy_(tmp_weight_data)
     if conv.bias is not None:
-        print('final conv.bias is not None')
         new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
     return new_conv
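
As context for the hunks above, replace_linear constructs a narrower torch.nn.Linear sized by the surviving input indices from the mask. A self-contained sketch of that idea (shrink_linear and the index_select weight copy are illustrative assumptions; only the constructor arguments mirror the diff):

import torch

def shrink_linear(linear, index):
    # `index` plays the role of mask.input_mask.mask_index[-1]:
    # a 1-D LongTensor listing the input features to keep
    new_linear = torch.nn.Linear(in_features=index.size()[0],
                                 out_features=linear.out_features,
                                 bias=linear.bias is not None)
    # assumed weight copy: keep only the selected input columns
    new_linear.weight.data = torch.index_select(linear.weight.data, 1, index)
    if linear.bias is not None:
        new_linear.bias.data.copy_(linear.bias.data)
    return new_linear

old = torch.nn.Linear(8, 4)
kept = torch.tensor([0, 2, 5], dtype=torch.long)
print(shrink_linear(old, kept).weight.shape)  # torch.Size([4, 3])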
39 changes: 17 additions & 22 deletions src/sdk/pynni/nni/compression/speedup/torch/compressor.py
@@ -158,7 +158,7 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
"""
# TODO: scope name could be empty
node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
#print('node_name: ', node_name)
_logger.debug("expand non-prim node, node name: %s", node_name)
self.global_count += 1
op_type = node.kind()

@@ -173,7 +173,6 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
             input_name = _input.debugName()
             if input_name in output_to_node and output_to_node[input_name] in nodes:
                 predecessor_node = output_to_node[input_name]
-                #print("predecessor_node: ", predecessor_node)
                 if predecessor_node.kind().startswith('prim::'):
                     node_group.append(predecessor_node)
                     node_queue.put(predecessor_node)
@@ -231,7 +230,7 @@ def _build_graph(self):
"""
graph = self.trace_graph.graph
# if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
#print(graph)
#_logger.debug(graph)
# build output mapping, from output debugName to its node
output_to_node = dict()
# build input mapping, from input debugName to its node
@@ -301,10 +300,8 @@ def _build_graph(self):
                     m_inputs.append(_input)
                 elif not output_to_node[_input] in nodes:
                     m_inputs.append(_input)
-            print("module node_name: ", module_name)
             if module_name == '':
-                for n in nodes:
-                    print(n)
+                _logger.warning("module_name is empty string")
             g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
             self.g_nodes.append(g_node)

@@ -345,10 +342,7 @@ def _find_predecessors(self, module_name):
         predecessors = []
         for _input in self.name_to_gnode[module_name].inputs:
             if not _input in self.output_to_gnode:
-                print(_input)
-            if not _input in self.output_to_gnode:
-                # TODO: check _input which does not have node
-                print("output with no gnode: ", _input)
+                _logger.debug("cannot find gnode with %s as its output", _input)
             else:
                 g_node = self.output_to_gnode[_input]
                 predecessors.append(g_node.name)
@@ -407,15 +401,15 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
         self.inferred_masks[module_name] = module_masks
 
         m_type = self.name_to_gnode[module_name].op_type
-        print("infer_module_mask: {}, module type: {}".format(module_name, m_type))
+        _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
         if mask is not None:
-            #print("mask is not None")
+            _logger.debug("mask is not None")
             if not m_type in infer_from_mask:
                 raise RuntimeError("Has not supported infering \
                     input/output shape from mask for module/function: `{}`".format(m_type))
             input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
         if in_shape is not None:
-            #print("in_shape is not None")
+            _logger.debug("in_shape is not None")
             if not m_type in infer_from_inshape:
                 raise RuntimeError("Has not supported infering \
                     output shape from input shape for module/function: `{}`".format(m_type))
@@ -426,23 +420,19 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
         else:
             output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
         if out_shape is not None:
-            #print("out_shape is not None")
+            _logger.debug("out_shape is not None")
             if not m_type in infer_from_outshape:
                 raise RuntimeError("Has not supported infering \
                     input shape from output shape for module/function: `{}`".format(m_type))
             input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
 
         if input_cmask:
-            #print("input_cmask is not None")
             predecessors = self._find_predecessors(module_name)
             for _module_name in predecessors:
-                print("input_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, out_shape=input_cmask)
         if output_cmask:
-            #print("output_cmask is not None")
             successors = self._find_successors(module_name)
             for _module_name in successors:
-                print("output_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, in_shape=output_cmask)
 
     def infer_modules_masks(self):
Expand All @@ -463,16 +453,19 @@ def replace_compressed_modules(self):
"""
for module_name in self.inferred_masks:
g_node = self.name_to_gnode[module_name]
print(module_name, g_node.op_type)
_logger.debug("replace %s, in %s type, with op_type %s",
module_name, g_node.type, g_node.op_type)
if g_node.type == 'module':
super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
m_type = g_node.op_type
if not m_type in replace_module:
raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
_logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
setattr(super_module, module_name.split('.')[-1], compressed_module)
elif g_node.type == 'func':
print("Warning: Cannot replace func...")
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
module_name, g_node.op_type)
else:
raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))

@@ -482,10 +475,12 @@ def speedup_model(self):
         first, do mask/shape inference,
         second, replace modules
         """
-        #print("start to compress")
+        _logger.info("start to speed up the model")
+        _logger.info("infer module masks...")
         self.infer_modules_masks()
+        _logger.info("replace compressed modules...")
         self.replace_compressed_modules()
-        #print("finished compressing")
+        _logger.info("speedup done")
         # resume the model mode to that before the model is speed up
         if self.is_training:
             self.bound_model.train()
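
Since the new messages go through logging at DEBUG or INFO level, they are silent under the default WARNING threshold; a caller who wants the old print-style output can enable it explicitly. A minimal sketch (the package logger name is an assumption based on the file paths shown above):

import logging

# show everything, including the new _logger.debug(...) messages
logging.basicConfig(level=logging.DEBUG,
                    format="%(name)s %(levelname)s: %(message)s")

# or narrow it to the speedup package only, assuming logger names
# follow the module paths in this commit
logging.getLogger("nni.compression.speedup.torch").setLevel(logging.DEBUG)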
