This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Error code for Speedup Module #4173

Merged · 5 commits · Sep 22, 2021
74 changes: 47 additions & 27 deletions nni/compression/pytorch/speedup/compress_modules.py
@@ -4,6 +4,7 @@
import logging
import torch
import torch.nn as nn
from .error_code import EmptyLayerError, ShapeMisMatchError, InputsNumberError, OutputTypeError, UnBalancedGroupError

_logger = logging.getLogger(__name__)

@@ -44,7 +45,6 @@
}



def convert_to_coarse_mask(t_mask, dim):
"""
Convert the mask tensor to the coarse-grained mask tensor.
Expand Down Expand Up @@ -87,6 +87,7 @@ def no_replace(module, masks):
_logger.debug("no need to replace")
return module


def replace_prelu(prelu, masks):
"""
Parameters
@@ -102,8 +103,11 @@ def replace_prelu(prelu, masks):
The new prelu module
"""
in_masks, output_mask, weight_mask = masks
assert len(in_masks) == 1
assert isinstance(output_mask, torch.Tensor)
if len(in_masks) != 1:
raise InputsNumberError()
if not isinstance(output_mask, torch.Tensor):
raise OutputTypeError(type(output_mask), torch.Tensor)

in_mask = in_masks[0]
weight_mask = weight_mask['weight']
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
@@ -112,13 +116,17 @@ def replace_prelu(prelu, masks):
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
remained_in, remained_out = remained_in.to(
prelu.weight.device), remained_out.to(prelu.weight.device)
assert n_remained_in == n_remained_out
if n_remained_in != n_remained_out:
raise ShapeMisMatchError()

if n_remained_in == 0:
return torch.nn.Identity()
new_prelu = torch.nn.PReLU(n_remained_in)
new_prelu.weight.data = torch.index_select(prelu.weight.data, 0, remained_in)
new_prelu.weight.data = torch.index_select(
prelu.weight.data, 0, remained_in)
return new_prelu


def replace_linear(linear, masks):
"""
This function will replace the original linear according to
@@ -142,8 +150,11 @@ def replace_linear(linear, masks):
"""
in_masks, output_mask, weight_mask = masks
assert isinstance(linear, nn.Linear)
assert len(in_masks) == 1
assert isinstance(output_mask, torch.Tensor)
if len(in_masks) != 1:
raise InputsNumberError()
if not isinstance(output_mask, torch.Tensor):
raise OutputTypeError(type(output_mask), torch.Tensor)

in_mask = in_masks[0]

weight_mask = weight_mask['weight']
@@ -199,7 +210,8 @@ def replace_batchnorm1d(norm, masks):
# N, C, H, W
_, remained_in = convert_to_coarse_mask(in_mask, 1)
_, remained_out = convert_to_coarse_mask(output_mask, 1)
assert remained_in.size(0) == remained_out.size(0)
if remained_in.size(0) != remained_out.size(0):
raise ShapeMisMatchError()

num_features = remained_in.size(0)
_logger.info("replace batchnorm1d with num_features: %d", num_features)
@@ -241,7 +253,8 @@ def replace_batchnorm2d(norm, masks):
# N, C, H, W
_, remained_in = convert_to_coarse_mask(in_mask, 1)
_, remained_out = convert_to_coarse_mask(output_mask, 1)
assert remained_in.size(0) == remained_out.size(0)
if remained_in.size(0) != remained_out.size(0):
raise ShapeMisMatchError()

num_features = remained_in.size(0)
_logger.info("replace batchnorm2d with num_features: %d", num_features)
@@ -261,7 +274,6 @@ def replace_batchnorm2d(norm, masks):
return new_norm



def replace_conv2d(conv, masks):
"""
Replace the original conv with a new one according to the infered
@@ -285,7 +297,8 @@ def replace_conv2d(conv, masks):
in_masks, output_mask, weight_masks = masks
assert isinstance(conv, nn.Conv2d)
# the conv layer should only have one input tensor
assert len(in_masks) == 1
if len(in_masks) != 1:
raise InputsNumberError()

in_mask = in_masks[0]

@@ -296,8 +309,8 @@ def replace_conv2d(conv, masks):
n_remained_in = weight_mask.size(1) * conv.groups - pruned_in.size(0)
n_remained_out = weight_mask.size(0) - pruned_out.size(0)

assert n_remained_in == remained_in.size(0)
assert n_remained_out == remained_out.size(0)
if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
raise ShapeMisMatchError()

k_size1, k_size2 = conv.kernel_size
# Note: We should resolve the group dependency of the conv layers before
@@ -331,9 +344,10 @@ def replace_conv2d(conv, masks):
tmp_weight = torch.ones(
n_remained_out, new_inchannel_step, k_size1, k_size2)
tmp_weight = tmp_weight.to(conv.weight.device)

assert n_remained_in % new_inchannel_step == 0
assert n_remained_out % new_outchannel_step == 0
if new_inchannel_step == 0 or new_outchannel_step == 0:
raise EmptyLayerError()
if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
raise UnBalancedGroupError()

new_groups = 0
for groupid in range(conv.groups):
@@ -352,8 +366,9 @@ def replace_conv2d(conv, masks):
assert len(current_output_index) == 0
continue
# check if the number of remained channel of each group are the same
assert len(current_input_index) == new_inchannel_step
assert len(current_output_index) == new_outchannel_step
if len(current_input_index) != new_inchannel_step or len(current_output_index) != new_outchannel_step:
raise UnBalancedGroupError()

# copy the weight into tmp_weight
new_out_start = new_outchannel_step * new_groups
new_out_end = new_out_start + new_outchannel_step
@@ -386,7 +401,6 @@ def replace_conv2d(conv, masks):
new_conv.bias.data.copy_(torch.index_select(
conv.bias.data, 0, remained_out))


return new_conv
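
As an aside on the UnBalancedGroupError check above: for a grouped convolution, every group has to keep the same number of remaining channels after pruning. A toy, self-contained sketch of that invariant (the mask and group sizes below are made up for illustration and are not part of this PR):

```python
import torch

groups, in_per_group = 4, 8
channel_mask = torch.ones(groups * in_per_group)
channel_mask[3] = 0  # prune one input channel from group 0 only

# count the remaining channels group by group
remained_per_group = channel_mask.view(groups, in_per_group).sum(dim=1)
if len(set(remained_per_group.tolist())) != 1:
    # groups now disagree (7 vs. 8 channels); this is the situation
    # UnBalancedGroupError reports during speedup
    print("unbalanced groups:", remained_per_group.tolist())
```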


@@ -410,7 +424,8 @@ def replace_convtranspose2d(convtrans, masks):
"""
in_masks, output_mask, weight_masks = masks
assert isinstance(convtrans, torch.nn.ConvTranspose2d)
assert len(in_masks) == 1
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]

weight_mask = weight_masks['weight']
@@ -420,8 +435,9 @@ def replace_convtranspose2d(convtrans, masks):
n_remained_in = weight_mask.size(0) - pruned_in.size(0)
n_remained_out = weight_mask.size(
1) * convtrans.groups - pruned_out.size(0)
assert n_remained_in == remained_in.size(0)
assert n_remained_out == remained_out.size(0)
if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
raise ShapeMisMatchError()

k_size1, k_size2 = convtrans.kernel_size
# Note: we should resolve the group dependency of the convtrans layers before
# run into this function
@@ -448,8 +464,10 @@ def replace_convtranspose2d(convtrans, masks):
n_remained_in, new_outchannel_step, k_size1, k_size2)
tmp_weight = tmp_weight.to(convtrans.weight.device)

assert n_remained_in % new_inchannel_step == 0
assert n_remained_out % new_outchannel_step == 0
if new_inchannel_step == 0 or new_outchannel_step == 0:
raise EmptyLayerError()
if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
raise UnBalancedGroupError()

new_groups = 0
for groupid in range(convtrans.groups):
@@ -471,8 +489,9 @@ def replace_convtranspose2d(convtrans, masks):
assert len(current_output_index) == 0
continue
# check if the number of remained channel of each group are the same
assert len(current_input_index) == new_inchannel_step
assert len(current_output_index) == new_outchannel_step
if len(current_input_index) != new_inchannel_step or len(current_output_index) != new_outchannel_step:
raise UnBalancedGroupError()

# copy the weight into tmp_weight
new_in_start = new_inchannel_step * new_groups
new_in_end = new_in_start + new_inchannel_step
@@ -505,7 +524,8 @@ def replace_convtranspose2d(convtrans, masks):
def replace_layernorm(layernorm, masks):
in_masks, _, _ = masks
assert isinstance(layernorm, nn.LayerNorm)
assert len(in_masks) == 1
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
dim_n = len(in_mask.size())
new_shape = []
1 change: 1 addition & 0 deletions nni/compression/pytorch/speedup/compressor.py
@@ -15,6 +15,7 @@
from .jit_translate import jit_to_python_function
from ..utils import rand_like_with_shape


_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

31 changes: 31 additions & 0 deletions nni/compression/pytorch/speedup/error_code.py
@@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Error Code of the speedup
class SpeedupError(Exception):
    def __init__(self, msg):
        self.msg = msg

    def __str__(self):
        return str(self.msg)

class EmptyLayerError(SpeedupError):
    def __init__(self):
        super(EmptyLayerError, self).__init__("Pruning a Layer to empty is not legal")
Contributor
In this situation, should we suggest that the user add this layer to the exclude list?

Contributor Author
Not exactly the same; there are two reasons a layer may be pruned to empty: (1) its sparsity ratio equals 1.0, or (2) its output is useless after the mask propagation.

Contributor
got it
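
For illustration only (not part of this diff): a minimal sketch of how a caller might react to this error, assuming the usual ModelSpeedup entry point; model, dummy_input, and masks_file below are placeholders standing in for the user's own pruning run.

```python
import torch
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.speedup.error_code import EmptyLayerError

model = ...                        # the pruned model (placeholder)
dummy_input = torch.rand(1, 3, 224, 224)
masks_file = './mask.pth'          # masks exported by the pruner (placeholder path)

try:
    ModelSpeedup(model, dummy_input, masks_file).speedup_model()
except EmptyLayerError:
    # either the layer's sparsity ratio was 1.0, or its output became useless
    # after mask propagation; exclude the layer from pruning (or lower its
    # sparsity), regenerate the masks, and retry the speedup
    raise
```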


class ShapeMisMatchError(SpeedupError):
    def __init__(self):
        super(ShapeMisMatchError, self).__init__("Shape mismatch!")
Contributor
Maybe we can also add the layer name to the error message for debugging?

Contributor Author
Good suggestion, but I cannot get the op_name in the module replace functions under the current interfaces. To keep the replacement-function interface clean, I suggest just reporting the error type for now.
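
If that ever becomes useful, one hypothetical way to attach the op name without touching the replacement-function interface is a caller-side wrapper; replace_with_context, replace_fn, and op_name below are illustrative names and not part of this PR.

```python
from nni.compression.pytorch.speedup.error_code import SpeedupError

def replace_with_context(op_name, replace_fn, module, masks):
    """Call a replace_* function and prepend the op name if it raises."""
    try:
        return replace_fn(module, masks)
    except SpeedupError as err:
        raise SpeedupError(f"{op_name}: {err}") from err
```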


class InputsNumberError(SpeedupError):
    def __init__(self):
        super(InputsNumberError, self).__init__("The number of the inputs of the target OP is wrong")

class OutputTypeError(SpeedupError):
    def __init__(self, current_type, target_type):
        msg = f"The output type should be {str(target_type)}, but {str(current_type)} found"
        super(OutputTypeError, self).__init__(msg)

class UnBalancedGroupError(SpeedupError):
    def __init__(self):
        msg = "The number of remaining filters in each group is different"
        super(UnBalancedGroupError, self).__init__(msg)
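
Since every new exception derives from SpeedupError, callers can catch the whole family with a single except clause; a quick check against the classes added above:

```python
from nni.compression.pytorch.speedup.error_code import SpeedupError, ShapeMisMatchError

try:
    raise ShapeMisMatchError()
except SpeedupError as err:  # the shared base class catches any speedup error
    print(err)               # prints: Shape mismatch!
```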