diff --git a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py index 858db63a94..d0fa74e51b 100644 --- a/src/sdk/pynni/nni/compression/torch/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/torch/builtin_pruners.py @@ -2,7 +2,7 @@ import torch from .compressor import Pruner -__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ] +__all__ = ['LevelPruner', 'AGP_Pruner', 'SensitivityPruner'] logger = logging.getLogger('torch pruner') @@ -10,6 +10,7 @@ class LevelPruner(Pruner): """Prune to an exact pruning level specification """ + def __init__(self, config_list): """ config_list: supported keys: @@ -21,9 +22,9 @@ def calc_mask(self, weight, config, **kwargs): w_abs = weight.abs() k = int(weight.numel() * config['sparsity']) if k == 0: - return torch.ones(weight.shape) - threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() - return torch.gt(w_abs, threshold).type(weight.type()) + return torch.ones(weight.shape).type_as(weight) + threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max() + return torch.gt(w_abs, threshold).type_as(weight) class AGP_Pruner(Pruner): @@ -35,12 +36,13 @@ class AGP_Pruner(Pruner): Learning of Phones and other Consumer Devices, https://arxiv.org/pdf/1710.01878.pdf """ + def __init__(self, config_list): """ config_list: supported keys: - initial_sparsity - final_sparsity: you should make sure initial_sparsity <= final_sparsity - - start_epoch: start epoch numer begin update mask + - start_epoch: start epoch number begin update mask - end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch - frequency: if you want update every 2 epoch, you can set it 2 """ @@ -49,15 +51,15 @@ def __init__(self, config_list): self.now_epoch = 1 def calc_mask(self, weight, config, op_name, **kwargs): - mask = self.mask_list.get(op_name, torch.ones(weight.shape)) + mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight)) target_sparsity = self.compute_target_sparsity(config) k = int(weight.numel() * target_sparsity) if k == 0 or target_sparsity >= 1 or target_sparsity <= 0: return mask # if we want to generate new mask, we should update weigth first - w_abs = weight.abs()*mask - threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() - new_mask = torch.gt(w_abs, threshold).type(weight.type()) + w_abs = weight.abs() * mask + threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max() + new_mask = torch.gt(w_abs, threshold).type_as(weight) self.mask_list[op_name] = new_mask return new_mask @@ -74,18 +76,18 @@ def compute_target_sparsity(self, config): if end_epoch <= self.now_epoch: return final_sparsity - span = ((end_epoch - start_epoch-1)//freq)*freq + span = ((end_epoch - start_epoch - 1) // freq) * freq assert span > 0 - target_sparsity = (final_sparsity + - (initial_sparsity - final_sparsity)* - (1.0 - ((self.now_epoch - start_epoch)/span))**3) + target_sparsity = (final_sparsity + + (initial_sparsity - final_sparsity) * + (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3) return target_sparsity def update_epoch(self, epoch): if epoch > 0: self.now_epoch = epoch - - + + class SensitivityPruner(Pruner): """Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks" https://arxiv.org/pdf/1506.02626v3.pdf @@ -93,6 +95,7 @@ class SensitivityPruner(Pruner): I.e.: "The pruning threshold is chosen as a quality parameter multiplied by the standard deviation of a layers weights." """ + def __init__(self, config_list): """ config_list: supported keys: @@ -100,19 +103,18 @@ def __init__(self, config_list): """ super().__init__(config_list) self.mask_list = {} - - + def calc_mask(self, weight, config, op_name, **kwargs): - mask = self.mask_list.get(op_name, torch.ones(weight.shape)) - # if we want to generate new mask, we should update weigth first - weight = weight*mask + mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight)) + # if we want to generate new mask, we should update weight first + weight = weight * mask target_sparsity = config['sparsity'] * torch.std(weight).item() k = int(weight.numel() * target_sparsity) if k == 0: return mask - + w_abs = weight.abs() - threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() - new_mask = torch.gt(w_abs, threshold).type(weight.type()) + threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max() + new_mask = torch.gt(w_abs, threshold).type_as(weight) self.mask_list[op_name] = new_mask return new_mask