This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

compression speedup: small code refactor #2065

Merged 41 commits on Feb 17, 2020
Commits (41)
de817c2 update (QuanluZhang, Jan 1, 2020)
9e4a3d9 update (QuanluZhang, Jan 10, 2020)
1db91e3 Merge branch 'master' of https://github.com/Microsoft/nni into dev-co… (QuanluZhang, Jan 13, 2020)
e401f2b update (QuanluZhang, Jan 15, 2020)
7ba30c3 update (QuanluZhang, Jan 20, 2020)
dc865fe update (QuanluZhang, Jan 21, 2020)
10c0510 update (QuanluZhang, Jan 21, 2020)
df1dda7 update (QuanluZhang, Jan 22, 2020)
9680f3e update (QuanluZhang, Jan 24, 2020)
e51f288 update (Jan 28, 2020)
f830430 update (Jan 28, 2020)
ab7f23d update (Jan 28, 2020)
98e75c2 pass eval result validate, but has very small difference (Jan 29, 2020)
ff413d1 add model_speedup.py (Jan 29, 2020)
d83f190 update (Jan 30, 2020)
ff7e79d pass fpgm test (Jan 31, 2020)
e1240fe add doc for speedup (Jan 31, 2020)
8d333f2 pass l1filter (Jan 31, 2020)
b1b2b14 update (Jan 31, 2020)
e988f19 update (Feb 1, 2020)
b8da18d Merge branch 'dev-pruner-dataparallel' of https://github.com/microsof… (Feb 5, 2020)
12485c7 Merge branch 'dev-pruner-dataparallel' of https://github.com/microsof… (Feb 5, 2020)
fbb6d48 remove test files (Feb 5, 2020)
1ce3c72 update (Feb 5, 2020)
4db78f7 update (Feb 5, 2020)
3d51727 update (Feb 5, 2020)
70d3b1e add comments (Feb 5, 2020)
c80c7a9 add comments (Feb 6, 2020)
005a664 add comments (Feb 6, 2020)
d11a54a add comments (Feb 6, 2020)
49e0de1 update (Feb 6, 2020)
951b014 resolve comments (Feb 8, 2020)
280fb1b update doc (Feb 10, 2020)
c96f8b1 Merge branch 'v1.4' of https://github.com/microsoft/nni into dev-comp… (Feb 15, 2020)
4c47da7 add init file (Feb 15, 2020)
553879b remove doc (Feb 15, 2020)
61be340 fix pylint (Feb 15, 2020)
6fe23aa replace print with logging (Feb 15, 2020)
e461c28 update (Feb 15, 2020)
77cc67f Merge branch 'v1.4' of https://github.com/microsoft/nni into dev-comp… (Feb 15, 2020)
87f2da7 update (Feb 15, 2020)
11 changes: 7 additions & 4 deletions src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
@@ -1,9 +1,12 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import logging
 import torch
 from .infer_shape import ModuleMasks
 
+_logger = logging.getLogger(__name__)
+
 replace_module = {
     'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
     'Conv2d': lambda module, mask: replace_conv2d(module, mask),
@@ -16,6 +19,7 @@ def no_replace(module, mask):
     """
     No need to replace
     """
+    _logger.debug("no need to replace")
     return module
 
 def replace_linear(linear, mask):
@@ -37,9 +41,8 @@ def replace_linear(linear, mask):
     assert mask.output_mask is None
     assert not mask.param_masks
     index = mask.input_mask.mask_index[-1]
-    print(mask.input_mask.mask_index)
     in_features = index.size()[0]
-    print('linear: ', in_features)
+    _logger.debug("replace linear with new in_features: %d", in_features)
     new_linear = torch.nn.Linear(in_features=in_features,
                                  out_features=linear.out_features,
                                  bias=linear.bias is not None)
@@ -67,7 +70,7 @@ def replace_batchnorm2d(norm, mask):
     assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
     index = mask.param_masks['weight'].mask_index[0]
     num_features = index.size()[0]
-    print("replace batchnorm2d: ", num_features, index)
+    _logger.debug("replace batchnorm2d with num_features: %d", num_features)
     new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                     eps=norm.eps,
                                     momentum=norm.momentum,
@@ -106,6 +109,7 @@ def replace_conv2d(conv, mask):
     else:
         out_channels_index = mask.output_mask.mask_index[1]
     out_channels = out_channels_index.size()[0]
+    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d", in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
@@ -128,6 +132,5 @@ def replace_conv2d(conv, mask):
     assert tmp_weight_data is not None, "Conv2d weight should be updated based on masks"
     new_conv.weight.data.copy_(tmp_weight_data)
     if conv.bias is not None:
-        print('final conv.bias is not None')
         new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
     return new_conv
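
The refactor above follows one pattern throughout: a module-level logger plus lazy %-style formatting instead of bare print calls. A minimal standalone sketch of that pattern; `shrink_linear` is a hypothetical stand-in for the `replace_*` functions, not NNI code:

```python
import logging

import torch

_logger = logging.getLogger(__name__)

def shrink_linear(linear, in_features):
    """Rebuild a Linear layer with fewer input features (illustration only)."""
    # Lazy %-formatting: the message is interpolated only when the DEBUG
    # level is enabled, so the hot path pays no string-building cost,
    # unlike print() or str.format(), which always build the string.
    _logger.debug("replace linear with new in_features: %d", in_features)
    return torch.nn.Linear(in_features=in_features,
                           out_features=linear.out_features,
                           bias=linear.bias is not None)
```

This is also why the new calls pass `%d`/`%s` placeholders and arguments separately rather than pre-formatting the message before handing it to the logger.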
39 changes: 17 additions & 22 deletions src/sdk/pynni/nni/compression/speedup/torch/compressor.py
@@ -158,7 +158,7 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
         """
         # TODO: scope name could be empty
         node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
-        #print('node_name: ', node_name)
+        _logger.debug("expand non-prim node, node name: %s", node_name)
         self.global_count += 1
         op_type = node.kind()
 
@@ -173,7 +173,6 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
                 input_name = _input.debugName()
                 if input_name in output_to_node and output_to_node[input_name] in nodes:
                     predecessor_node = output_to_node[input_name]
-                    #print("predecessor_node: ", predecessor_node)
                     if predecessor_node.kind().startswith('prim::'):
                         node_group.append(predecessor_node)
                         node_queue.put(predecessor_node)
@@ -231,7 +230,7 @@ def _build_graph(self):
         """
         graph = self.trace_graph.graph
         # if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
-        #print(graph)
+        #_logger.debug(graph)
         # build output mapping, from output debugName to its node
         output_to_node = dict()
         # build input mapping, from input debugName to its node
@@ -301,10 +300,8 @@ def _build_graph(self):
                     m_inputs.append(_input)
                 elif not output_to_node[_input] in nodes:
                     m_inputs.append(_input)
-            print("module node_name: ", module_name)
             if module_name == '':
-                for n in nodes:
-                    print(n)
+                _logger.warning("module_name is empty string")
             g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
             self.g_nodes.append(g_node)
 
@@ -345,10 +342,7 @@ def _find_predecessors(self, module_name):
         predecessors = []
         for _input in self.name_to_gnode[module_name].inputs:
             if not _input in self.output_to_gnode:
-                print(_input)
-            if not _input in self.output_to_gnode:
-                # TODO: check _input which does not have node
-                print("output with no gnode: ", _input)
+                _logger.debug("cannot find gnode with %s as its output", _input)
             else:
                 g_node = self.output_to_gnode[_input]
                 predecessors.append(g_node.name)
@@ -407,15 +401,15 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
         self.inferred_masks[module_name] = module_masks
 
         m_type = self.name_to_gnode[module_name].op_type
-        print("infer_module_mask: {}, module type: {}".format(module_name, m_type))
+        _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
         if mask is not None:
-            #print("mask is not None")
+            _logger.debug("mask is not None")
             if not m_type in infer_from_mask:
                 raise RuntimeError("Has not supported infering \
                     input/output shape from mask for module/function: `{}`".format(m_type))
             input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask)
         if in_shape is not None:
-            #print("in_shape is not None")
+            _logger.debug("in_shape is not None")
             if not m_type in infer_from_inshape:
                 raise RuntimeError("Has not supported infering \
                     output shape from input shape for module/function: `{}`".format(m_type))
@@ -426,23 +420,19 @@ def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
             else:
                 output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
         if out_shape is not None:
-            #print("out_shape is not None")
+            _logger.debug("out_shape is not None")
             if not m_type in infer_from_outshape:
                 raise RuntimeError("Has not supported infering \
                     input shape from output shape for module/function: `{}`".format(m_type))
             input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
 
         if input_cmask:
-            #print("input_cmask is not None")
             predecessors = self._find_predecessors(module_name)
             for _module_name in predecessors:
-                print("input_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, out_shape=input_cmask)
         if output_cmask:
-            #print("output_cmask is not None")
             successors = self._find_successors(module_name)
             for _module_name in successors:
-                print("output_cmask, module_name: ", _module_name)
                 self.infer_module_mask(_module_name, in_shape=output_cmask)
 
     def infer_modules_masks(self):
@@ -463,16 +453,19 @@ def replace_compressed_modules(self):
         """
         for module_name in self.inferred_masks:
             g_node = self.name_to_gnode[module_name]
-            print(module_name, g_node.op_type)
+            _logger.debug("replace %s, in %s type, with op_type %s",
+                          module_name, g_node.type, g_node.op_type)
             if g_node.type == 'module':
                 super_module, leaf_module = get_module_by_name(self.bound_model, module_name)
                 m_type = g_node.op_type
                 if not m_type in replace_module:
                     raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
+                _logger.info("replace module (name: %s, op_type: %s)", module_name, m_type)
                 compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name])
                 setattr(super_module, module_name.split('.')[-1], compressed_module)
             elif g_node.type == 'func':
-                print("Warning: Cannot replace func...")
+                _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
+                             module_name, g_node.op_type)
             else:
                 raise RuntimeError("Unsupported GNode type: {}".format(g_node.type))
 
@@ -482,10 +475,12 @@ def speedup_model(self):
         first, do mask/shape inference,
         second, replace modules
         """
-        #print("start to compress")
+        _logger.info("start to speed up the model")
+        _logger.info("infer module masks...")
         self.infer_modules_masks()
+        _logger.info("replace compressed modules...")
         self.replace_compressed_modules()
-        #print("finished compressing")
+        _logger.info("speedup done")
         # resume the model mode to that before the model is speed up
         if self.is_training:
             self.bound_model.train()
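
With the prints gone, verbosity is now controlled entirely by the caller's logging configuration. A sketch of how a user might surface the new messages when speeding up a pruned model; the import path and the `ModelSpeedup(model, dummy_input, masks_file)` signature are assumptions inferred from this PR's file layout and the "add model_speedup.py" commit, not guaranteed by the diff:

```python
import logging

import torch
# Assumed import path and class name; check the NNI release you are on.
from nni.compression.speedup.torch import ModelSpeedup

# The diff names each logger after its module (__name__), so raising the
# package logger's level surfaces the debug/info messages added above.
logging.basicConfig(format='%(name)s %(levelname)s: %(message)s')
logging.getLogger('nni.compression.speedup.torch').setLevel(logging.DEBUG)

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
)
dummy_input = torch.randn(1, 3, 32, 32)

# 'mask.pth' stands in for a mask file produced by an NNI pruner.
speedup = ModelSpeedup(model, dummy_input, 'mask.pth')  # assumed signature
speedup.speedup_model()  # expect "start to speed up the model" ... "speedup done"
```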