Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
compression speedup: add init file (#2063)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang authored Feb 15, 2020
1 parent b4ab371 commit b8c0fb6
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import torch
from .infer_shape import CoarseMask, ModuleMasks
from .infer_shape import ModuleMasks

replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
Expand Down
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/compression/speedup/torch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _find_successors(self, module_name):
def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
"""
Infer input shape / output shape based on the module's weight mask / input shape / output shape.
For a module:
Infer its input and output shape from its weight mask
Infer its output shape from its input shape
Expand Down
11 changes: 5 additions & 6 deletions src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def merge_index(index_a, index_b):
s.add(num)
for num in index_b:
s.add(num)
return torch.tensor(sorted(s))
return torch.tensor(sorted(s)) # pylint: disable=not-callable

def merge(self, cmask):
"""
Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(self, module_name):
self.param_masks = dict()
self.input_mask = None
self.output_mask = None

def set_param_masks(self, name, mask):
"""
Parameters
Expand Down Expand Up @@ -217,7 +217,7 @@ def view_inshape(module_masks, mask, shape):
TODO: consider replace tensor.view with nn.Flatten, because tensor.view is not
included in module, thus, cannot be replaced by our framework.
Parameters
----------
module_masks : ModuleMasks
Expand Down Expand Up @@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape):
step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index))
output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
module_masks.set_output_mask(output_cmask)
return output_cmask

Expand Down Expand Up @@ -373,7 +373,6 @@ def convert_to_coarse_mask(mask):
"""
assert 'weight' in mask
assert isinstance(mask['weight'], torch.Tensor)
cmask = None
weight_mask = mask['weight']
shape = weight_mask.size()
ones = torch.ones(shape[1:]).to(weight_mask.device)
Expand Down Expand Up @@ -451,7 +450,7 @@ def conv2d_outshape(module_masks, mask):
The ModuleMasks instance of the conv2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
Expand Down

0 comments on commit b8c0fb6

Please sign in to comment.