Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
fix builtin pruners bug (#1612)
Browse files Browse the repository at this point in the history
* fix builtin pruners bug
  • Loading branch information
tanglang96 authored and QuanluZhang committed Oct 16, 2019
1 parent d6b61e2 commit fd551c8
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions src/sdk/pynni/nni/compression/torch/builtin_pruners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import torch
from .compressor import Pruner

__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ]
__all__ = ['LevelPruner', 'AGP_Pruner', 'SensitivityPruner']

logger = logging.getLogger('torch pruner')


class LevelPruner(Pruner):
"""Prune to an exact pruning level specification
"""

def __init__(self, config_list):
"""
config_list: supported keys:
Expand All @@ -21,9 +22,9 @@ def calc_mask(self, weight, config, **kwargs):
w_abs = weight.abs()
k = int(weight.numel() * config['sparsity'])
if k == 0:
return torch.ones(weight.shape)
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
return torch.gt(w_abs, threshold).type(weight.type())
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
return torch.gt(w_abs, threshold).type_as(weight)


class AGP_Pruner(Pruner):
Expand All @@ -35,12 +36,13 @@ class AGP_Pruner(Pruner):
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
"""

def __init__(self, config_list):
"""
config_list: supported keys:
- initial_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity
- start_epoch: start epoch numer begin update mask
- start_epoch: start epoch number begin update mask
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- frequency: if you want update every 2 epoch, you can set it 2
"""
Expand All @@ -49,15 +51,15 @@ def __init__(self, config_list):
self.now_epoch = 1

def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape))
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weigth first
w_abs = weight.abs()*mask
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type())
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask
return new_mask

Expand All @@ -74,45 +76,45 @@ def compute_target_sparsity(self, config):
if end_epoch <= self.now_epoch:
return final_sparsity

span = ((end_epoch - start_epoch-1)//freq)*freq
span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity)*
(1.0 - ((self.now_epoch - start_epoch)/span))**3)
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity

def update_epoch(self, epoch):
if epoch > 0:
self.now_epoch = epoch


class SensitivityPruner(Pruner):
"""Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
https://arxiv.org/pdf/1506.02626v3.pdf
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
"""

def __init__(self, config_list):
"""
config_list: supported keys:
- sparsity: chosen pruning sparsity
"""
super().__init__(config_list)
self.mask_list = {}



def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape))
# if we want to generate new mask, we should update weigth first
weight = weight*mask
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
# if we want to generate new mask, we should update weight first
weight = weight * mask
target_sparsity = config['sparsity'] * torch.std(weight).item()
k = int(weight.numel() * target_sparsity)
if k == 0:
return mask

w_abs = weight.abs()
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type())
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask
return new_mask

0 comments on commit fd551c8

Please sign in to comment.